diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go index b0edcf5a..835e5fd8 100644 --- a/backend/internal/handler/auth_linuxdo_oauth.go +++ b/backend/internal/handler/auth_linuxdo_oauth.go @@ -2,6 +2,8 @@ package handler import ( "context" + "crypto/hmac" + "crypto/sha256" "encoding/base64" "errors" "fmt" @@ -17,6 +19,7 @@ import ( infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" "github.com/Wei-Shaw/sub2api/internal/pkg/response" + servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" @@ -25,17 +28,24 @@ import ( ) const ( - linuxDoOAuthCookiePath = "/api/v1/auth/oauth/linuxdo" - linuxDoOAuthStateCookieName = "linuxdo_oauth_state" - linuxDoOAuthVerifierCookie = "linuxdo_oauth_verifier" - linuxDoOAuthRedirectCookie = "linuxdo_oauth_redirect" - linuxDoOAuthCookieMaxAgeSec = 10 * 60 // 10 minutes - linuxDoOAuthDefaultRedirectTo = "/dashboard" - linuxDoOAuthDefaultFrontendCB = "/auth/linuxdo/callback" + linuxDoOAuthCookiePath = "/api/v1/auth/oauth/linuxdo" + oauthBindAccessTokenCookiePath = "/api/v1/auth/oauth" + linuxDoOAuthStateCookieName = "linuxdo_oauth_state" + linuxDoOAuthVerifierCookie = "linuxdo_oauth_verifier" + linuxDoOAuthRedirectCookie = "linuxdo_oauth_redirect" + linuxDoOAuthIntentCookieName = "linuxdo_oauth_intent" + linuxDoOAuthBindUserCookieName = "linuxdo_oauth_bind_user" + oauthBindAccessTokenCookieName = "oauth_bind_access_token" + linuxDoOAuthCookieMaxAgeSec = 10 * 60 // 10 minutes + linuxDoOAuthDefaultRedirectTo = "/dashboard" + linuxDoOAuthDefaultFrontendCB = "/auth/linuxdo/callback" linuxDoOAuthMaxRedirectLen = 2048 linuxDoOAuthMaxFragmentValueLen = 512 linuxDoOAuthMaxSubjectLen = 64 - len("linuxdo-") + + oauthIntentLogin = "login" + oauthIntentBindCurrentUser = "bind_current_user" ) type linuxDoTokenResponse struct { @@ -96,8 +106,20 @@ func (h *AuthHandler) LinuxDoOAuthStart(c *gin.Context) { secureCookie := isRequestHTTPS(c) setCookie(c, linuxDoOAuthStateCookieName, encodeCookieValue(state), linuxDoOAuthCookieMaxAgeSec, secureCookie) setCookie(c, linuxDoOAuthRedirectCookie, encodeCookieValue(redirectTo), linuxDoOAuthCookieMaxAgeSec, secureCookie) + intent := normalizeOAuthIntent(c.Query("intent")) + setCookie(c, linuxDoOAuthIntentCookieName, encodeCookieValue(intent), linuxDoOAuthCookieMaxAgeSec, secureCookie) setOAuthPendingBrowserCookie(c, browserSessionKey, secureCookie) clearOAuthPendingSessionCookie(c, secureCookie) + if intent == oauthIntentBindCurrentUser { + bindCookieValue, err := h.buildOAuthBindUserCookieFromContext(c) + if err != nil { + response.ErrorFrom(c, err) + return + } + setCookie(c, linuxDoOAuthBindUserCookieName, encodeCookieValue(bindCookieValue), linuxDoOAuthCookieMaxAgeSec, secureCookie) + } else { + clearCookie(c, linuxDoOAuthBindUserCookieName, secureCookie) + } verifier, err := oauth.GenerateCodeVerifier() if err != nil { @@ -153,6 +175,8 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { clearCookie(c, linuxDoOAuthStateCookieName, secureCookie) clearCookie(c, linuxDoOAuthVerifierCookie, secureCookie) clearCookie(c, linuxDoOAuthRedirectCookie, secureCookie) + clearCookie(c, linuxDoOAuthIntentCookieName, secureCookie) + clearCookie(c, linuxDoOAuthBindUserCookieName, secureCookie) }() expectedState, err := readCookieDecoded(c, linuxDoOAuthStateCookieName) @@ -171,6 +195,8 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { redirectOAuthError(c, frontendCallback, "missing_browser_session", "missing oauth browser session", "") return } + intent, _ := readCookieDecoded(c, linuxDoOAuthIntentCookieName) + intent = normalizeOAuthIntent(intent) codeVerifier, _ := readCookieDecoded(c, linuxDoOAuthVerifierCookie) if codeVerifier == "" { @@ -217,6 +243,40 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { if subject != "" { email = linuxDoSyntheticEmail(subject) } + if intent == oauthIntentBindCurrentUser { + targetUserID, err := h.readOAuthBindUserIDFromCookie(c, linuxDoOAuthBindUserCookieName) + if err != nil { + redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth bind target", "") + return + } + if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ + Intent: oauthIntentBindCurrentUser, + Identity: service.PendingAuthIdentityKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: subject, + }, + TargetUserID: &targetUserID, + ResolvedEmail: email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: map[string]any{ + "email": email, + "username": username, + "subject": subject, + "suggested_display_name": displayName, + "suggested_avatar_url": avatarURL, + }, + CompletionResponse: map[string]any{ + "redirect": redirectTo, + }, + }); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth bind", "") + return + } + redirectToFrontendCallback(c, frontendCallback) + return + } // 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "") @@ -784,6 +844,18 @@ func clearCookie(c *gin.Context, name string, secure bool) { }) } +func clearOAuthBindAccessTokenCookie(c *gin.Context, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: oauthBindAccessTokenCookieName, + Value: "", + Path: oauthBindAccessTokenCookiePath, + MaxAge: -1, + HttpOnly: false, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + func truncateFragmentValue(value string) string { value = strings.TrimSpace(value) if value == "" { @@ -842,3 +914,107 @@ func linuxDoSyntheticEmail(subject string) string { } return "linuxdo-" + subject + service.LinuxDoConnectSyntheticEmailDomain } + +func normalizeOAuthIntent(raw string) string { + switch strings.ToLower(strings.TrimSpace(raw)) { + case "", oauthIntentLogin: + return oauthIntentLogin + case "bind", oauthIntentBindCurrentUser: + return oauthIntentBindCurrentUser + default: + return oauthIntentLogin + } +} + +func (h *AuthHandler) buildOAuthBindUserCookieFromContext(c *gin.Context) (string, error) { + userID, err := h.resolveOAuthBindTargetUserID(c) + if err != nil || userID == nil || *userID <= 0 { + return "", infraerrors.Unauthorized("UNAUTHORIZED", "authentication required") + } + return buildOAuthBindUserCookieValue(*userID, h.oauthBindCookieSecret()) +} + +func (h *AuthHandler) resolveOAuthBindTargetUserID(c *gin.Context) (*int64, error) { + if subject, ok := servermiddleware.GetAuthSubjectFromContext(c); ok && subject.UserID > 0 { + return &subject.UserID, nil + } + if h == nil || h.authService == nil || h.userService == nil { + return nil, service.ErrInvalidToken + } + + ck, err := c.Request.Cookie(oauthBindAccessTokenCookieName) + clearOAuthBindAccessTokenCookie(c, isRequestHTTPS(c)) + if err != nil { + return nil, err + } + + tokenString, err := url.QueryUnescape(strings.TrimSpace(ck.Value)) + if err != nil { + return nil, err + } + if tokenString == "" { + return nil, service.ErrInvalidToken + } + + claims, err := h.authService.ValidateToken(tokenString) + if err != nil { + return nil, err + } + user, err := h.userService.GetByID(c.Request.Context(), claims.UserID) + if err != nil { + return nil, err + } + if user == nil || !user.IsActive() || claims.TokenVersion != user.TokenVersion { + return nil, service.ErrInvalidToken + } + return &user.ID, nil +} + +func (h *AuthHandler) readOAuthBindUserIDFromCookie(c *gin.Context, cookieName string) (int64, error) { + value, err := readCookieDecoded(c, cookieName) + if err != nil { + return 0, err + } + return parseOAuthBindUserCookieValue(value, h.oauthBindCookieSecret()) +} + +func (h *AuthHandler) oauthBindCookieSecret() string { + if h == nil || h.cfg == nil { + return "" + } + return strings.TrimSpace(h.cfg.JWT.Secret) +} + +func buildOAuthBindUserCookieValue(userID int64, secret string) (string, error) { + secret = strings.TrimSpace(secret) + if userID <= 0 || secret == "" { + return "", errors.New("invalid oauth bind cookie input") + } + payload := strconv.FormatInt(userID, 10) + mac := hmac.New(sha256.New, []byte(secret)) + _, _ = mac.Write([]byte(payload)) + signature := base64.RawURLEncoding.EncodeToString(mac.Sum(nil)) + return payload + "." + signature, nil +} + +func parseOAuthBindUserCookieValue(value string, secret string) (int64, error) { + secret = strings.TrimSpace(secret) + if secret == "" { + return 0, errors.New("missing oauth bind cookie secret") + } + payload, signature, ok := strings.Cut(strings.TrimSpace(value), ".") + if !ok || payload == "" || signature == "" { + return 0, errors.New("invalid oauth bind cookie") + } + mac := hmac.New(sha256.New, []byte(secret)) + _, _ = mac.Write([]byte(payload)) + expectedSignature := base64.RawURLEncoding.EncodeToString(mac.Sum(nil)) + if !hmac.Equal([]byte(signature), []byte(expectedSignature)) { + return 0, errors.New("invalid oauth bind cookie signature") + } + userID, err := strconv.ParseInt(payload, 10, 64) + if err != nil || userID <= 0 { + return 0, errors.New("invalid oauth bind cookie user") + } + return userID, nil +} diff --git a/backend/internal/handler/auth_linuxdo_oauth_test.go b/backend/internal/handler/auth_linuxdo_oauth_test.go index 661c0da0..765779b5 100644 --- a/backend/internal/handler/auth_linuxdo_oauth_test.go +++ b/backend/internal/handler/auth_linuxdo_oauth_test.go @@ -9,11 +9,13 @@ import ( "testing" "time" + dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/authidentity" "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" dbuser "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/internal/config" + servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" @@ -122,6 +124,321 @@ func TestSingleLineStripsWhitespace(t *testing.T) { require.Equal(t, "", singleLine("\n\t\r")) } +func TestLinuxDoOAuthBindStartRedirectsAndSetsBindCookies(t *testing.T) { + handler := newLinuxDoOAuthTestHandler(t, false, config.LinuxDoConnectConfig{ + Enabled: true, + ClientID: "linuxdo-client", + ClientSecret: "linuxdo-secret", + AuthorizeURL: "https://connect.linux.do/oauth/authorize", + TokenURL: "https://connect.linux.do/oauth/token", + UserInfoURL: "https://connect.linux.do/api/user", + Scopes: "read", + RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback", + FrontendRedirectURL: "/auth/linuxdo/callback", + TokenAuthMethod: "client_secret_post", + UsePKCE: true, + }) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/bind/start?intent=bind_current_user&redirect=/settings/connections", nil) + c.Request = req + c.Set(string(servermiddleware.ContextKeyUser), servermiddleware.AuthSubject{UserID: 42}) + + handler.LinuxDoOAuthStart(c) + + require.Equal(t, http.StatusFound, recorder.Code) + location := recorder.Header().Get("Location") + require.Contains(t, location, "connect.linux.do/oauth/authorize") + require.Contains(t, location, "client_id=linuxdo-client") + require.Contains(t, location, "code_challenge=") + + cookies := recorder.Result().Cookies() + require.NotNil(t, findCookie(cookies, linuxDoOAuthStateCookieName)) + require.NotNil(t, findCookie(cookies, linuxDoOAuthRedirectCookie)) + require.NotNil(t, findCookie(cookies, linuxDoOAuthVerifierCookie)) + require.NotNil(t, findCookie(cookies, oauthPendingBrowserCookieName)) + + intentCookie := findCookie(cookies, linuxDoOAuthIntentCookieName) + require.NotNil(t, intentCookie) + require.Equal(t, oauthIntentBindCurrentUser, decodeCookieValueForTest(t, intentCookie.Value)) + + bindCookie := findCookie(cookies, linuxDoOAuthBindUserCookieName) + require.NotNil(t, bindCookie) + userID, err := parseOAuthBindUserCookieValue(decodeCookieValueForTest(t, bindCookie.Value), "test-secret") + require.NoError(t, err) + require.Equal(t, int64(42), userID) +} + +func TestLinuxDoOAuthBindStartAcceptsAccessTokenCookie(t *testing.T) { + handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{ + Enabled: true, + ClientID: "linuxdo-client", + ClientSecret: "linuxdo-secret", + AuthorizeURL: "https://connect.linux.do/oauth/authorize", + TokenURL: "https://connect.linux.do/oauth/token", + UserInfoURL: "https://connect.linux.do/api/user", + Scopes: "read", + RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback", + FrontendRedirectURL: "/auth/linuxdo/callback", + TokenAuthMethod: "client_secret_post", + UsePKCE: true, + }) + defer client.Close() + + user, err := client.User.Create(). + SetEmail("bind-cookie@example.com"). + SetUsername("bind-cookie-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(context.Background()) + require.NoError(t, err) + + token, err := handler.authService.GenerateToken(&service.User{ + ID: user.ID, + Email: user.Email, + Username: user.Username, + PasswordHash: user.PasswordHash, + Role: user.Role, + Status: user.Status, + }) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/start?intent=bind_current_user&redirect=/settings/connections", nil) + req.AddCookie(&http.Cookie{Name: oauthBindAccessTokenCookieName, Value: token, Path: oauthBindAccessTokenCookiePath}) + c.Request = req + + handler.LinuxDoOAuthStart(c) + + require.Equal(t, http.StatusFound, recorder.Code) + + bindCookie := findCookie(recorder.Result().Cookies(), linuxDoOAuthBindUserCookieName) + require.NotNil(t, bindCookie) + userID, err := parseOAuthBindUserCookieValue(decodeCookieValueForTest(t, bindCookie.Value), "test-secret") + require.NoError(t, err) + require.Equal(t, user.ID, userID) + + accessTokenCookie := findCookie(recorder.Result().Cookies(), oauthBindAccessTokenCookieName) + require.NotNil(t, accessTokenCookie) + require.Equal(t, -1, accessTokenCookie.MaxAge) +} + +func TestLinuxDoOAuthCallbackCreatesLoginPendingSessionForExistingUser(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/token": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`)) + case "/userinfo": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"321","username":"linuxdo_user","name":"LinuxDo Display","avatar_url":"https://cdn.example/linuxdo.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + + handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{ + Enabled: true, + ClientID: "linuxdo-client", + ClientSecret: "linuxdo-secret", + AuthorizeURL: upstream.URL + "/authorize", + TokenURL: upstream.URL + "/token", + UserInfoURL: upstream.URL + "/userinfo", + Scopes: "read", + RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback", + FrontendRedirectURL: "/auth/linuxdo/callback", + TokenAuthMethod: "client_secret_post", + UsePKCE: true, + }) + defer client.Close() + + ctx := context.Background() + existingUser, err := client.User.Create(). + SetEmail(linuxDoSyntheticEmail("321")). + SetUsername("legacy-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=code-123&state=state-123", nil) + req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-123")) + req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/dashboard")) + req.AddCookie(encodedCookie(linuxDoOAuthVerifierCookie, "verifier-123")) + req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentLogin)) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + c.Request = req + + handler.LinuxDoOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/linuxdo/callback", recorder.Header().Get("Location")) + + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + session, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). + Only(ctx) + require.NoError(t, err) + require.Equal(t, oauthIntentLogin, session.Intent) + require.NotNil(t, session.TargetUserID) + require.Equal(t, existingUser.ID, *session.TargetUserID) + require.Equal(t, linuxDoSyntheticEmail("321"), session.ResolvedEmail) + require.Equal(t, "LinuxDo Display", session.UpstreamIdentityClaims["suggested_display_name"]) + + completion := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + require.Equal(t, "/dashboard", completion["redirect"]) + require.NotEmpty(t, completion["access_token"]) + require.Nil(t, completion["error"]) +} + +func TestLinuxDoOAuthCallbackCreatesInvitationPendingSessionWhenSignupRequiresInvite(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/token": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`)) + case "/userinfo": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"654","username":"linuxdo_invite","name":"Need Invite","avatar_url":"https://cdn.example/invite.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + + handler, client := newLinuxDoOAuthHandlerAndClient(t, true, config.LinuxDoConnectConfig{ + Enabled: true, + ClientID: "linuxdo-client", + ClientSecret: "linuxdo-secret", + AuthorizeURL: upstream.URL + "/authorize", + TokenURL: upstream.URL + "/token", + UserInfoURL: upstream.URL + "/userinfo", + Scopes: "read", + RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback", + FrontendRedirectURL: "/auth/linuxdo/callback", + TokenAuthMethod: "client_secret_post", + UsePKCE: true, + }) + defer client.Close() + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=code-456&state=state-456", nil) + req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-456")) + req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/dashboard")) + req.AddCookie(encodedCookie(linuxDoOAuthVerifierCookie, "verifier-456")) + req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentLogin)) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-456")) + c.Request = req + + handler.LinuxDoOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/linuxdo/callback", recorder.Header().Get("Location")) + + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + ctx := context.Background() + session, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). + Only(ctx) + require.NoError(t, err) + require.Equal(t, oauthIntentLogin, session.Intent) + require.Nil(t, session.TargetUserID) + + completion := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + require.Equal(t, "invitation_required", completion["error"]) + require.Equal(t, "/dashboard", completion["redirect"]) +} + +func TestLinuxDoOAuthCallbackCreatesBindPendingSessionForCurrentUser(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/token": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`)) + case "/userinfo": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"999","username":"bind_user","name":"Bind Display","avatar_url":"https://cdn.example/bind.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + + handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{ + Enabled: true, + ClientID: "linuxdo-client", + ClientSecret: "linuxdo-secret", + AuthorizeURL: upstream.URL + "/authorize", + TokenURL: upstream.URL + "/token", + UserInfoURL: upstream.URL + "/userinfo", + Scopes: "read", + RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback", + FrontendRedirectURL: "/auth/linuxdo/callback", + TokenAuthMethod: "client_secret_post", + UsePKCE: true, + }) + defer client.Close() + + ctx := context.Background() + currentUser, err := client.User.Create(). + SetEmail("current@example.com"). + SetUsername("current-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=code-bind&state=state-bind", nil) + req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-bind")) + req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/settings/connections")) + req.AddCookie(encodedCookie(linuxDoOAuthVerifierCookie, "verifier-bind")) + req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentBindCurrentUser)) + req.AddCookie(encodedCookie(linuxDoOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret"))) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-bind")) + c.Request = req + + handler.LinuxDoOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/linuxdo/callback", recorder.Header().Get("Location")) + + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + session, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). + Only(ctx) + require.NoError(t, err) + require.Equal(t, oauthIntentBindCurrentUser, session.Intent) + require.NotNil(t, session.TargetUserID) + require.Equal(t, currentUser.ID, *session.TargetUserID) + require.Equal(t, linuxDoSyntheticEmail("999"), session.ResolvedEmail) + + completion := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + require.Equal(t, "/settings/connections", completion["redirect"]) + require.Empty(t, completion["access_token"]) + require.Equal(t, "Bind Display", session.UpstreamIdentityClaims["suggested_display_name"]) + + userCount, err := client.User.Query().Count(ctx) + require.NoError(t, err) + require.Equal(t, 1, userCount) +} + func TestCompleteLinuxDoOAuthRegistrationAppliesPendingAdoptionDecision(t *testing.T) { handler, client := newOAuthPendingFlowTestHandler(t, false) ctx := context.Background() @@ -197,3 +514,25 @@ func TestCompleteLinuxDoOAuthRegistrationAppliesPendingAdoptionDecision(t *testi require.NoError(t, err) require.NotNil(t, consumed.ConsumedAt) } + +func newLinuxDoOAuthTestHandler(t *testing.T, invitationEnabled bool, oauthCfg config.LinuxDoConnectConfig) *AuthHandler { + t.Helper() + handler, _ := newLinuxDoOAuthHandlerAndClient(t, invitationEnabled, oauthCfg) + return handler +} + +func newLinuxDoOAuthHandlerAndClient(t *testing.T, invitationEnabled bool, oauthCfg config.LinuxDoConnectConfig) (*AuthHandler, *dbent.Client) { + t.Helper() + handler, client := newOAuthPendingFlowTestHandler(t, invitationEnabled) + handler.settingSvc = nil + handler.cfg = &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-secret", + ExpireHour: 1, + AccessTokenExpireMinutes: 60, + RefreshTokenExpireDays: 7, + }, + LinuxDo: oauthCfg, + } + return handler, client +} diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go index da8ac858..2d6c3714 100644 --- a/backend/internal/handler/auth_oauth_pending_flow.go +++ b/backend/internal/handler/auth_oauth_pending_flow.go @@ -391,6 +391,16 @@ func ensurePendingOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, sessio return create.Save(ctx) } +func shouldBindPendingOAuthIdentity(session *dbent.PendingAuthSession, decision *dbent.IdentityAdoptionDecision) bool { + if session == nil || decision == nil { + return false + } + if strings.EqualFold(strings.TrimSpace(session.Intent), "bind_current_user") { + return true + } + return decision.AdoptDisplayName || decision.AdoptAvatar +} + func applyPendingOAuthAdoption( ctx context.Context, client *dbent.Client, @@ -401,7 +411,7 @@ func applyPendingOAuthAdoption( if client == nil || session == nil || decision == nil { return nil } - if !decision.AdoptDisplayName && !decision.AdoptAvatar { + if !shouldBindPendingOAuthIdentity(session, decision) { return nil } diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go index 829fc217..3afb4fb7 100644 --- a/backend/internal/handler/auth_oauth_pending_flow_test.go +++ b/backend/internal/handler/auth_oauth_pending_flow_test.go @@ -167,6 +167,348 @@ func TestExchangePendingOAuthCompletionPreviewThenFinalizeAppliesAdoptionDecisio require.NotNil(t, consumed.ConsumedAt) } +func TestExchangePendingOAuthCompletionBindCurrentUserPreviewThenFinalizeBindsIdentityWithoutAdoption(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + userEntity, err := client.User.Create(). + SetEmail("bind-target@example.com"). + SetUsername("legacy-name"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("bind-pending-session-token"). + SetIntent("bind_current_user"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("bind-123"). + SetTargetUserID(userEntity.ID). + SetResolvedEmail(userEntity.Email). + SetBrowserSessionKey("bind-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "linuxdo_user", + "suggested_display_name": "Bound Example", + "suggested_avatar_url": "https://cdn.example/bound.png", + }). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "access_token": "access-token", + "redirect": "/settings/profile", + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + previewRecorder := httptest.NewRecorder() + previewCtx, _ := gin.CreateTestContext(previewRecorder) + previewReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil) + previewReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + previewReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-browser-session-key")}) + previewCtx.Request = previewReq + + handler.ExchangePendingOAuthCompletion(previewCtx) + + require.Equal(t, http.StatusOK, previewRecorder.Code) + previewData := decodeJSONResponseData(t, previewRecorder) + require.Equal(t, "Bound Example", previewData["suggested_display_name"]) + require.Equal(t, "https://cdn.example/bound.png", previewData["suggested_avatar_url"]) + require.Equal(t, true, previewData["adoption_required"]) + + identityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("linuxdo"), + authidentity.ProviderKeyEQ("linuxdo"), + authidentity.ProviderSubjectEQ("bind-123"), + ). + Count(ctx) + require.NoError(t, err) + require.Zero(t, identityCount) + + previewSession, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.Nil(t, previewSession.ConsumedAt) + + body := bytes.NewBufferString(`{"adopt_display_name":false,"adopt_avatar":false}`) + finalizeRecorder := httptest.NewRecorder() + finalizeCtx, _ := gin.CreateTestContext(finalizeRecorder) + finalizeReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body) + finalizeReq.Header.Set("Content-Type", "application/json") + finalizeReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + finalizeReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-browser-session-key")}) + finalizeCtx.Request = finalizeReq + + handler.ExchangePendingOAuthCompletion(finalizeCtx) + + require.Equal(t, http.StatusOK, finalizeRecorder.Code) + + storedUser, err := client.User.Get(ctx, userEntity.ID) + require.NoError(t, err) + require.Equal(t, "legacy-name", storedUser.Username) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("linuxdo"), + authidentity.ProviderKeyEQ("linuxdo"), + authidentity.ProviderSubjectEQ("bind-123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, userEntity.ID, identity.UserID) + require.Equal(t, "Bound Example", identity.Metadata["suggested_display_name"]) + require.Equal(t, "https://cdn.example/bound.png", identity.Metadata["suggested_avatar_url"]) + _, hasDisplayName := identity.Metadata["display_name"] + require.False(t, hasDisplayName) + _, hasAvatarURL := identity.Metadata["avatar_url"] + require.False(t, hasAvatarURL) + + decision, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, decision.IdentityID) + require.Equal(t, identity.ID, *decision.IdentityID) + require.False(t, decision.AdoptDisplayName) + require.False(t, decision.AdoptAvatar) + + consumed, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, consumed.ConsumedAt) +} + +func TestExchangePendingOAuthCompletionBindCurrentUserOwnershipConflict(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + targetUser, err := client.User.Create(). + SetEmail("bind-conflict-target@example.com"). + SetUsername("target-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + ownerUser, err := client.User.Create(). + SetEmail("bind-conflict-owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + existingIdentity, err := client.AuthIdentity.Create(). + SetUserID(ownerUser.ID). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("conflict-123"). + SetMetadata(map[string]any{"username": "owner-user"}). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("bind-conflict-session-token"). + SetIntent("bind_current_user"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("conflict-123"). + SetTargetUserID(targetUser.ID). + SetResolvedEmail(targetUser.Email). + SetBrowserSessionKey("bind-conflict-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "suggested_display_name": "Conflict Example", + "suggested_avatar_url": "https://cdn.example/conflict.png", + }). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "access_token": "access-token", + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"adopt_display_name":false,"adopt_avatar":false}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-conflict-browser-session-key")}) + ginCtx.Request = req + + handler.ExchangePendingOAuthCompletion(ginCtx) + + require.Equal(t, http.StatusInternalServerError, recorder.Code) + payload := decodeJSONBody(t, recorder) + require.Equal(t, "PENDING_AUTH_ADOPTION_APPLY_FAILED", payload["reason"]) + + identity, err := client.AuthIdentity.Get(ctx, existingIdentity.ID) + require.NoError(t, err) + require.Equal(t, ownerUser.ID, identity.UserID) + + decision, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.Nil(t, decision.IdentityID) + require.False(t, decision.AdoptDisplayName) + require.False(t, decision.AdoptAvatar) + + storedSession, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + +func TestExchangePendingOAuthCompletionLoginFalseFalseDoesNotBindIdentity(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + userEntity, err := client.User.Create(). + SetEmail("login-false@example.com"). + SetUsername("legacy-name"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("login-false-session-token"). + SetIntent("login"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("login-false-123"). + SetTargetUserID(userEntity.ID). + SetResolvedEmail(userEntity.Email). + SetBrowserSessionKey("login-false-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "suggested_display_name": "Login Example", + "suggested_avatar_url": "https://cdn.example/login.png", + }). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "access_token": "access-token", + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"adopt_display_name":false,"adopt_avatar":false}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("login-false-browser-session-key")}) + ginCtx.Request = req + + handler.ExchangePendingOAuthCompletion(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + + identityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("linuxdo"), + authidentity.ProviderKeyEQ("linuxdo"), + authidentity.ProviderSubjectEQ("login-false-123"), + ). + Count(ctx) + require.NoError(t, err) + require.Zero(t, identityCount) + + decision, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.Nil(t, decision.IdentityID) + require.False(t, decision.AdoptDisplayName) + require.False(t, decision.AdoptAvatar) + + storedSession, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, storedSession.ConsumedAt) +} + +func TestExchangePendingOAuthCompletionInvitationRequiredFalseFalsePersistsDecisionWithoutBinding(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, true) + ctx := context.Background() + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("invitation-required-session-token"). + SetIntent("login"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("invitation-123"). + SetBrowserSessionKey("invitation-required-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "suggested_display_name": "Invite Example", + "suggested_avatar_url": "https://cdn.example/invite.png", + }). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "error": "invitation_required", + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"adopt_display_name":false,"adopt_avatar":false}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("invitation-required-browser-session-key")}) + ginCtx.Request = req + + handler.ExchangePendingOAuthCompletion(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + data := decodeJSONResponseData(t, recorder) + require.Equal(t, "invitation_required", data["error"]) + require.Equal(t, true, data["adoption_required"]) + + identityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("linuxdo"), + authidentity.ProviderKeyEQ("linuxdo"), + authidentity.ProviderSubjectEQ("invitation-123"), + ). + Count(ctx) + require.NoError(t, err) + require.Zero(t, identityCount) + + decision, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.Nil(t, decision.IdentityID) + require.False(t, decision.AdoptDisplayName) + require.False(t, decision.AdoptAvatar) + + storedSession, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + func newOAuthPendingFlowTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) { t.Helper() @@ -198,9 +540,10 @@ func newOAuthPendingFlowTestHandler(t *testing.T, invitationEnabled bool) (*Auth service.SettingKeyInvitationCodeEnabled: boolSettingValue(invitationEnabled), }, }, cfg) + userRepo := &oauthPendingFlowUserRepo{client: client} authSvc := service.NewAuthService( client, - &oauthPendingFlowUserRepo{client: client}, + userRepo, nil, &oauthPendingFlowRefreshTokenCacheStub{}, cfg, @@ -211,9 +554,11 @@ func newOAuthPendingFlowTestHandler(t *testing.T, invitationEnabled bool) (*Auth nil, nil, ) + userSvc := service.NewUserService(userRepo, nil, nil, nil) return &AuthHandler{ authService: authSvc, + userService: userSvc, settingSvc: settingSvc, }, client } @@ -414,7 +759,7 @@ func (r *oauthPendingFlowUserRepo) Delete(ctx context.Context, id int64) error { } func (r *oauthPendingFlowUserRepo) GetUserAvatar(context.Context, int64) (*service.UserAvatar, error) { - return nil, service.ErrUserNotFound + return nil, nil } func (r *oauthPendingFlowUserRepo) UpsertUserAvatar(context.Context, int64, service.UpsertUserAvatarInput) (*service.UserAvatar, error) { @@ -462,6 +807,33 @@ func (r *oauthPendingFlowUserRepo) RemoveGroupFromUserAllowedGroups(context.Cont panic("unexpected RemoveGroupFromUserAllowedGroups call") } +func (r *oauthPendingFlowUserRepo) ListUserAuthIdentities(ctx context.Context, userID int64) ([]service.UserAuthIdentityRecord, error) { + identities, err := r.client.AuthIdentity.Query(). + Where(authidentity.UserIDEQ(userID)). + All(ctx) + if err != nil { + return nil, err + } + + records := make([]service.UserAuthIdentityRecord, 0, len(identities)) + for _, identity := range identities { + if identity == nil { + continue + } + records = append(records, service.UserAuthIdentityRecord{ + ProviderType: identity.ProviderType, + ProviderKey: identity.ProviderKey, + ProviderSubject: identity.ProviderSubject, + VerifiedAt: identity.VerifiedAt, + Issuer: identity.Issuer, + Metadata: identity.Metadata, + CreatedAt: identity.CreatedAt, + UpdatedAt: identity.UpdatedAt, + }) + } + return records, nil +} + func (r *oauthPendingFlowUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { panic("unexpected UpdateTotpSecret call") } diff --git a/backend/internal/handler/auth_oauth_test_helpers_test.go b/backend/internal/handler/auth_oauth_test_helpers_test.go new file mode 100644 index 00000000..8eb87dbb --- /dev/null +++ b/backend/internal/handler/auth_oauth_test_helpers_test.go @@ -0,0 +1,39 @@ +package handler + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +func buildEncodedOAuthBindUserCookie(t *testing.T, userID int64, secret string) string { + t.Helper() + value, err := buildOAuthBindUserCookieValue(userID, secret) + require.NoError(t, err) + return value +} + +func encodedCookie(name, value string) *http.Cookie { + return &http.Cookie{ + Name: name, + Value: encodeCookieValue(value), + Path: "/", + } +} + +func findCookie(cookies []*http.Cookie, name string) *http.Cookie { + for _, cookie := range cookies { + if cookie.Name == name { + return cookie + } + } + return nil +} + +func decodeCookieValueForTest(t *testing.T, value string) string { + t.Helper() + decoded, err := decodeCookieValue(value) + require.NoError(t, err) + return decoded +} diff --git a/backend/internal/handler/auth_oidc_oauth.go b/backend/internal/handler/auth_oidc_oauth.go index ceda633c..0f79759e 100644 --- a/backend/internal/handler/auth_oidc_oauth.go +++ b/backend/internal/handler/auth_oidc_oauth.go @@ -32,14 +32,16 @@ import ( ) const ( - oidcOAuthCookiePath = "/api/v1/auth/oauth/oidc" - oidcOAuthStateCookieName = "oidc_oauth_state" - oidcOAuthVerifierCookie = "oidc_oauth_verifier" - oidcOAuthRedirectCookie = "oidc_oauth_redirect" - oidcOAuthNonceCookie = "oidc_oauth_nonce" - oidcOAuthCookieMaxAgeSec = 10 * 60 // 10 minutes - oidcOAuthDefaultRedirectTo = "/dashboard" - oidcOAuthDefaultFrontendCB = "/auth/oidc/callback" + oidcOAuthCookiePath = "/api/v1/auth/oauth/oidc" + oidcOAuthStateCookieName = "oidc_oauth_state" + oidcOAuthVerifierCookie = "oidc_oauth_verifier" + oidcOAuthRedirectCookie = "oidc_oauth_redirect" + oidcOAuthNonceCookie = "oidc_oauth_nonce" + oidcOAuthIntentCookieName = "oidc_oauth_intent" + oidcOAuthBindUserCookieName = "oidc_oauth_bind_user" + oidcOAuthCookieMaxAgeSec = 10 * 60 // 10 minutes + oidcOAuthDefaultRedirectTo = "/dashboard" + oidcOAuthDefaultFrontendCB = "/auth/oidc/callback" ) type oidcTokenResponse struct { @@ -138,8 +140,20 @@ func (h *AuthHandler) OIDCOAuthStart(c *gin.Context) { secureCookie := isRequestHTTPS(c) oidcSetCookie(c, oidcOAuthStateCookieName, encodeCookieValue(state), oidcOAuthCookieMaxAgeSec, secureCookie) oidcSetCookie(c, oidcOAuthRedirectCookie, encodeCookieValue(redirectTo), oidcOAuthCookieMaxAgeSec, secureCookie) + intent := normalizeOAuthIntent(c.Query("intent")) + oidcSetCookie(c, oidcOAuthIntentCookieName, encodeCookieValue(intent), oidcOAuthCookieMaxAgeSec, secureCookie) setOAuthPendingBrowserCookie(c, browserSessionKey, secureCookie) clearOAuthPendingSessionCookie(c, secureCookie) + if intent == oauthIntentBindCurrentUser { + bindCookieValue, err := h.buildOAuthBindUserCookieFromContext(c) + if err != nil { + response.ErrorFrom(c, err) + return + } + oidcSetCookie(c, oidcOAuthBindUserCookieName, encodeCookieValue(bindCookieValue), oidcOAuthCookieMaxAgeSec, secureCookie) + } else { + oidcClearCookie(c, oidcOAuthBindUserCookieName, secureCookie) + } codeChallenge := "" verifier, genErr := oauth.GenerateCodeVerifier() @@ -205,6 +219,8 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { oidcClearCookie(c, oidcOAuthVerifierCookie, secureCookie) oidcClearCookie(c, oidcOAuthRedirectCookie, secureCookie) oidcClearCookie(c, oidcOAuthNonceCookie, secureCookie) + oidcClearCookie(c, oidcOAuthIntentCookieName, secureCookie) + oidcClearCookie(c, oidcOAuthBindUserCookieName, secureCookie) }() expectedState, err := readCookieDecoded(c, oidcOAuthStateCookieName) @@ -223,6 +239,8 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { redirectOAuthError(c, frontendCallback, "missing_browser_session", "missing oauth browser session", "") return } + intent, _ := readCookieDecoded(c, oidcOAuthIntentCookieName) + intent = normalizeOAuthIntent(intent) codeVerifier := "" codeVerifier, _ = readCookieDecoded(c, oidcOAuthVerifierCookie) @@ -324,6 +342,43 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { idClaims.Name, oidcFallbackUsername(subject), ) + if intent == oauthIntentBindCurrentUser { + targetUserID, err := h.readOAuthBindUserIDFromCookie(c, oidcOAuthBindUserCookieName) + if err != nil { + redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth bind target", "") + return + } + if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ + Intent: oauthIntentBindCurrentUser, + Identity: service.PendingAuthIdentityKey{ + ProviderType: "oidc", + ProviderKey: issuer, + ProviderSubject: subject, + }, + TargetUserID: &targetUserID, + ResolvedEmail: email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: map[string]any{ + "email": email, + "username": username, + "subject": subject, + "issuer": issuer, + "email_verified": emailVerified != nil && *emailVerified, + "provider_fallback": strings.TrimSpace(cfg.ProviderName), + "suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username), + "suggested_avatar_url": userInfoClaims.AvatarURL, + }, + CompletionResponse: map[string]any{ + "redirect": redirectTo, + }, + }); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth bind", "") + return + } + redirectToFrontendCallback(c, frontendCallback) + return + } // 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "") diff --git a/backend/internal/handler/auth_oidc_oauth_test.go b/backend/internal/handler/auth_oidc_oauth_test.go index 9107e13a..07f5ef68 100644 --- a/backend/internal/handler/auth_oidc_oauth_test.go +++ b/backend/internal/handler/auth_oidc_oauth_test.go @@ -13,11 +13,13 @@ import ( "testing" "time" + dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/authidentity" "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" dbuser "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/internal/config" + servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt/v5" @@ -131,6 +133,227 @@ func buildRSAJWK(kid string, pub *rsa.PublicKey) oidcJWK { } } +func TestOIDCOAuthBindStartRedirectsAndSetsBindCookies(t *testing.T) { + handler := newOIDCOAuthTestHandler(t, false, config.OIDCConnectConfig{ + Enabled: true, + ClientID: "oidc-client", + ClientSecret: "oidc-secret", + IssuerURL: "https://issuer.example.com", + AuthorizeURL: "https://issuer.example.com/oauth/authorize", + TokenURL: "https://issuer.example.com/oauth/token", + UserInfoURL: "https://issuer.example.com/oauth/userinfo", + JWKSURL: "https://issuer.example.com/oauth/jwks", + Scopes: "openid profile email", + RedirectURL: "https://api.example.com/api/v1/auth/oauth/oidc/callback", + FrontendRedirectURL: "/auth/oidc/callback", + TokenAuthMethod: "client_secret_post", + UsePKCE: true, + ValidateIDToken: true, + AllowedSigningAlgs: "RS256", + ClockSkewSeconds: 120, + RequireEmailVerified: false, + }) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/bind/start?intent=bind_current_user&redirect=/settings/connections", nil) + c.Request = req + c.Set(string(servermiddleware.ContextKeyUser), servermiddleware.AuthSubject{UserID: 84}) + + handler.OIDCOAuthStart(c) + + require.Equal(t, http.StatusFound, recorder.Code) + location := recorder.Header().Get("Location") + require.Contains(t, location, "issuer.example.com/oauth/authorize") + require.Contains(t, location, "client_id=oidc-client") + require.Contains(t, location, "nonce=") + + cookies := recorder.Result().Cookies() + require.NotNil(t, findCookie(cookies, oidcOAuthStateCookieName)) + require.NotNil(t, findCookie(cookies, oidcOAuthRedirectCookie)) + require.NotNil(t, findCookie(cookies, oidcOAuthVerifierCookie)) + require.NotNil(t, findCookie(cookies, oidcOAuthNonceCookie)) + require.NotNil(t, findCookie(cookies, oauthPendingBrowserCookieName)) + + intentCookie := findCookie(cookies, oidcOAuthIntentCookieName) + require.NotNil(t, intentCookie) + require.Equal(t, oauthIntentBindCurrentUser, decodeCookieValueForTest(t, intentCookie.Value)) + + bindCookie := findCookie(cookies, oidcOAuthBindUserCookieName) + require.NotNil(t, bindCookie) + userID, err := parseOAuthBindUserCookieValue(decodeCookieValueForTest(t, bindCookie.Value), "test-secret") + require.NoError(t, err) + require.Equal(t, int64(84), userID) +} + +func TestOIDCOAuthCallbackCreatesLoginPendingSessionForExistingUser(t *testing.T) { + cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{ + Subject: "oidc-subject-login", + PreferredUsername: "oidc_login", + DisplayName: "OIDC Login Display", + AvatarURL: "https://cdn.example/oidc-login.png", + Email: "oidc-login@example.com", + EmailVerified: true, + }) + defer cleanup() + + handler, client := newOIDCOAuthHandlerAndClient(t, false, cfg) + defer client.Close() + + ctx := context.Background() + existingUser, err := client.User.Create(). + SetEmail(oidcSyntheticEmailFromIdentityKey(oidcIdentityKey(cfg.IssuerURL, "oidc-subject-login"))). + SetUsername("legacy-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-123", nil) + req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-123")) + req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/dashboard")) + req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-123")) + req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-subject-login")) + req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin)) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + c.Request = req + + handler.OIDCOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/oidc/callback", recorder.Header().Get("Location")) + + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + session, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). + Only(ctx) + require.NoError(t, err) + require.Equal(t, oauthIntentLogin, session.Intent) + require.NotNil(t, session.TargetUserID) + require.Equal(t, existingUser.ID, *session.TargetUserID) + require.Equal(t, cfg.IssuerURL, session.ProviderKey) + require.Equal(t, "OIDC Login Display", session.UpstreamIdentityClaims["suggested_display_name"]) + + completion := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + require.Equal(t, "/dashboard", completion["redirect"]) + require.NotEmpty(t, completion["access_token"]) + require.Nil(t, completion["error"]) +} + +func TestOIDCOAuthCallbackCreatesInvitationPendingSessionWhenSignupRequiresInvite(t *testing.T) { + cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{ + Subject: "oidc-subject-invite", + PreferredUsername: "oidc_invite", + DisplayName: "OIDC Invite Display", + AvatarURL: "https://cdn.example/oidc-invite.png", + Email: "oidc-invite@example.com", + EmailVerified: true, + }) + defer cleanup() + + handler, client := newOIDCOAuthHandlerAndClient(t, true, cfg) + defer client.Close() + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-456", nil) + req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-456")) + req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/dashboard")) + req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-456")) + req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-subject-invite")) + req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin)) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-456")) + c.Request = req + + handler.OIDCOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/oidc/callback", recorder.Header().Get("Location")) + + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + ctx := context.Background() + session, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). + Only(ctx) + require.NoError(t, err) + require.Equal(t, oauthIntentLogin, session.Intent) + require.Nil(t, session.TargetUserID) + + completion := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + require.Equal(t, "invitation_required", completion["error"]) + require.Equal(t, "/dashboard", completion["redirect"]) +} + +func TestOIDCOAuthCallbackCreatesBindPendingSessionForCurrentUser(t *testing.T) { + cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{ + Subject: "oidc-subject-bind", + PreferredUsername: "oidc_bind", + DisplayName: "OIDC Bind Display", + AvatarURL: "https://cdn.example/oidc-bind.png", + Email: "oidc-bind@example.com", + EmailVerified: true, + }) + defer cleanup() + + handler, client := newOIDCOAuthHandlerAndClient(t, false, cfg) + defer client.Close() + + ctx := context.Background() + currentUser, err := client.User.Create(). + SetEmail("current@example.com"). + SetUsername("current-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-bind", nil) + req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-bind")) + req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/settings/connections")) + req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-bind")) + req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-subject-bind")) + req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentBindCurrentUser)) + req.AddCookie(encodedCookie(oidcOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret"))) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-bind")) + c.Request = req + + handler.OIDCOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/oidc/callback", recorder.Header().Get("Location")) + + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + session, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). + Only(ctx) + require.NoError(t, err) + require.Equal(t, oauthIntentBindCurrentUser, session.Intent) + require.NotNil(t, session.TargetUserID) + require.Equal(t, currentUser.ID, *session.TargetUserID) + require.Equal(t, cfg.IssuerURL, session.ProviderKey) + require.Equal(t, "OIDC Bind Display", session.UpstreamIdentityClaims["suggested_display_name"]) + + completion := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + require.Equal(t, "/settings/connections", completion["redirect"]) + require.Empty(t, completion["access_token"]) + + userCount, err := client.User.Query().Count(ctx) + require.NoError(t, err) + require.Equal(t, 1, userCount) +} + func TestCompleteOIDCOAuthRegistrationAppliesPendingAdoptionDecision(t *testing.T) { handler, client := newOAuthPendingFlowTestHandler(t, false) ctx := context.Background() @@ -207,3 +430,116 @@ func TestCompleteOIDCOAuthRegistrationAppliesPendingAdoptionDecision(t *testing. require.NoError(t, err) require.NotNil(t, consumed.ConsumedAt) } + +type oidcProviderFixture struct { + Subject string + PreferredUsername string + DisplayName string + AvatarURL string + Email string + EmailVerified bool +} + +func newOIDCOAuthTestHandler(t *testing.T, invitationEnabled bool, oauthCfg config.OIDCConnectConfig) *AuthHandler { + t.Helper() + handler, _ := newOIDCOAuthHandlerAndClient(t, invitationEnabled, oauthCfg) + return handler +} + +func newOIDCOAuthHandlerAndClient(t *testing.T, invitationEnabled bool, oauthCfg config.OIDCConnectConfig) (*AuthHandler, *dbent.Client) { + t.Helper() + handler, client := newOAuthPendingFlowTestHandler(t, invitationEnabled) + handler.settingSvc = nil + handler.cfg = &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-secret", + ExpireHour: 1, + AccessTokenExpireMinutes: 60, + RefreshTokenExpireDays: 7, + }, + OIDC: oauthCfg, + } + return handler, client +} + +func newOIDCTestProvider(t *testing.T, fixture oidcProviderFixture) (config.OIDCConnectConfig, func()) { + t.Helper() + + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + kid := "test-kid" + jwks := oidcJWKSet{Keys: []oidcJWK{buildRSAJWK(kid, &privateKey.PublicKey)}} + tokenResponse := oidcTokenResponse{ + AccessToken: "oidc-access-token", + TokenType: "Bearer", + ExpiresIn: 3600, + } + + userInfoPayload := map[string]any{ + "sub": fixture.Subject, + "preferred_username": fixture.PreferredUsername, + "name": fixture.DisplayName, + "picture": fixture.AvatarURL, + "email": fixture.Email, + "email_verified": fixture.EmailVerified, + } + + var issuer string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/token": + require.NoError(t, json.NewEncoder(w).Encode(tokenResponse)) + case "/userinfo": + require.NoError(t, json.NewEncoder(w).Encode(userInfoPayload)) + case "/jwks": + require.NoError(t, json.NewEncoder(w).Encode(jwks)) + default: + http.NotFound(w, r) + } + })) + + issuer = server.URL + now := time.Now() + claims := oidcIDTokenClaims{ + Email: fixture.Email, + EmailVerified: boolPtr(fixture.EmailVerified), + PreferredUsername: fixture.PreferredUsername, + Name: fixture.DisplayName, + Nonce: "nonce-" + fixture.Subject, + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: issuer, + Subject: fixture.Subject, + Audience: jwt.ClaimStrings{"oidc-client"}, + IssuedAt: jwt.NewNumericDate(now), + NotBefore: jwt.NewNumericDate(now.Add(-30 * time.Second)), + ExpiresAt: jwt.NewNumericDate(now.Add(5 * time.Minute)), + }, + } + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + token.Header["kid"] = kid + tokenResponse.IDToken, err = token.SignedString(privateKey) + require.NoError(t, err) + + cfg := config.OIDCConnectConfig{ + Enabled: true, + ProviderName: "Test OIDC", + ClientID: "oidc-client", + ClientSecret: "oidc-secret", + IssuerURL: issuer, + AuthorizeURL: issuer + "/authorize", + TokenURL: issuer + "/token", + UserInfoURL: issuer + "/userinfo", + JWKSURL: issuer + "/jwks", + Scopes: "openid profile email", + RedirectURL: "https://api.example.com/api/v1/auth/oauth/oidc/callback", + FrontendRedirectURL: "/auth/oidc/callback", + TokenAuthMethod: "client_secret_post", + UsePKCE: true, + ValidateIDToken: true, + AllowedSigningAlgs: "RS256", + ClockSkewSeconds: 120, + RequireEmailVerified: false, + } + return cfg, server.Close +} diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go index 867a77a1..45ac6cad 100644 --- a/backend/internal/handler/auth_wechat_oauth.go +++ b/backend/internal/handler/auth_wechat_oauth.go @@ -12,6 +12,9 @@ import ( "strings" "time" + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" "github.com/Wei-Shaw/sub2api/internal/pkg/response" @@ -27,6 +30,7 @@ const ( wechatOAuthRedirectCookieName = "wechat_oauth_redirect" wechatOAuthIntentCookieName = "wechat_oauth_intent" wechatOAuthModeCookieName = "wechat_oauth_mode" + wechatOAuthBindUserCookieName = "wechat_oauth_bind_user" wechatOAuthDefaultRedirectTo = "/dashboard" wechatOAuthDefaultFrontendCB = "/auth/wechat/callback" wechatOAuthProviderKey = "wechat-main" @@ -105,6 +109,16 @@ func (h *AuthHandler) WeChatOAuthStart(c *gin.Context) { wechatSetCookie(c, wechatOAuthModeCookieName, encodeCookieValue(cfg.mode), wechatOAuthCookieMaxAgeSec, secureCookie) setOAuthPendingBrowserCookie(c, browserSessionKey, secureCookie) clearOAuthPendingSessionCookie(c, secureCookie) + if intent == oauthIntentBindCurrentUser { + bindCookieValue, err := h.buildOAuthBindUserCookieFromContext(c) + if err != nil { + response.ErrorFrom(c, err) + return + } + wechatSetCookie(c, wechatOAuthBindUserCookieName, encodeCookieValue(bindCookieValue), wechatOAuthCookieMaxAgeSec, secureCookie) + } else { + wechatClearCookie(c, wechatOAuthBindUserCookieName, secureCookie) + } authURL, err := buildWeChatAuthorizeURL(cfg, state) if err != nil { @@ -138,6 +152,7 @@ func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) { wechatClearCookie(c, wechatOAuthRedirectCookieName, secureCookie) wechatClearCookie(c, wechatOAuthIntentCookieName, secureCookie) wechatClearCookie(c, wechatOAuthModeCookieName, secureCookie) + wechatClearCookie(c, wechatOAuthBindUserCookieName, secureCookie) }() expectedState, err := readCookieDecoded(c, wechatOAuthStateCookieName) @@ -193,13 +208,33 @@ func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) { "openid": openid, "unionid": unionid, "mode": cfg.mode, + "channel": cfg.mode, + "channel_app_id": strings.TrimSpace(cfg.appID), + "channel_subject": openid, "suggested_display_name": strings.TrimSpace(userInfo.Nickname), "suggested_avatar_url": strings.TrimSpace(userInfo.HeadImgURL), } + normalizedIntent := normalizeWeChatOAuthIntent(intent) + if normalizedIntent == wechatOAuthIntentBind { + if err := h.createWeChatBindPendingSession(c, cfg, providerSubject, openid, redirectTo, browserSessionKey, upstreamClaims); err != nil { + switch infraerrors.Code(err) { + case http.StatusConflict: + redirectOAuthError(c, frontendCallback, "ownership_conflict", infraerrors.Reason(err), infraerrors.Message(err)) + case http.StatusUnauthorized, http.StatusForbidden: + redirectOAuthError(c, frontendCallback, "auth_required", infraerrors.Reason(err), infraerrors.Message(err)) + default: + redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err)) + } + return + } + redirectToFrontendCallback(c, frontendCallback) + return + } + tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "") if err != nil { - if err := h.createWeChatPendingSession(c, normalizeWeChatOAuthIntent(intent), providerSubject, email, redirectTo, browserSessionKey, upstreamClaims, tokenPair, err); err != nil { + if err := h.createWeChatPendingSession(c, normalizedIntent, providerSubject, email, redirectTo, browserSessionKey, upstreamClaims, tokenPair, err, nil); err != nil { redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") return } @@ -207,7 +242,7 @@ func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) { return } - if err := h.createWeChatPendingSession(c, normalizeWeChatOAuthIntent(intent), providerSubject, email, redirectTo, browserSessionKey, upstreamClaims, tokenPair, nil); err != nil { + if err := h.createWeChatPendingSession(c, normalizedIntent, providerSubject, email, redirectTo, browserSessionKey, upstreamClaims, tokenPair, nil, nil); err != nil { redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") return } @@ -309,6 +344,7 @@ func (h *AuthHandler) createWeChatPendingSession( upstreamClaims map[string]any, tokenPair *service.TokenPair, authErr error, + targetUserID *int64, ) error { completionResponse := map[string]any{ "redirect": redirectTo, @@ -333,6 +369,7 @@ func (h *AuthHandler) createWeChatPendingSession( ProviderKey: wechatOAuthProviderKey, ProviderSubject: providerSubject, }, + TargetUserID: targetUserID, ResolvedEmail: email, RedirectTo: redirectTo, BrowserSessionKey: browserSessionKey, @@ -341,6 +378,106 @@ func (h *AuthHandler) createWeChatPendingSession( }) } +func (h *AuthHandler) createWeChatBindPendingSession( + c *gin.Context, + cfg wechatOAuthConfig, + providerSubject string, + channelSubject string, + redirectTo string, + browserSessionKey string, + upstreamClaims map[string]any, +) error { + currentUser, err := h.readOAuthBindTargetUser(c, wechatOAuthBindUserCookieName) + if err != nil { + return err + } + if err := h.ensureWeChatBindOwnership(c.Request.Context(), currentUser.ID, providerSubject, cfg, channelSubject); err != nil { + return err + } + return h.createWeChatPendingSession( + c, + wechatOAuthIntentBind, + providerSubject, + currentUser.Email, + redirectTo, + browserSessionKey, + upstreamClaims, + nil, + nil, + ¤tUser.ID, + ) +} + +func (h *AuthHandler) readOAuthBindTargetUser(c *gin.Context, cookieName string) (*dbent.User, error) { + client := h.entClient() + if client == nil { + return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + userID, err := h.readOAuthBindUserIDFromCookie(c, cookieName) + if err != nil { + return nil, infraerrors.Unauthorized("AUTH_REQUIRED", "current user is required to bind wechat account") + } + userEntity, err := client.User.Get(c.Request.Context(), userID) + if err != nil { + if dbent.IsNotFound(err) { + return nil, infraerrors.Unauthorized("AUTH_REQUIRED", "current user is required to bind wechat account") + } + return nil, infraerrors.InternalServer("WECHAT_BIND_USER_LOOKUP_FAILED", "failed to load current user").WithCause(err) + } + return userEntity, nil +} + +func (h *AuthHandler) ensureWeChatBindOwnership( + ctx context.Context, + userID int64, + providerSubject string, + cfg wechatOAuthConfig, + channelSubject string, +) error { + client := h.entClient() + if client == nil { + return infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("wechat"), + authidentity.ProviderKeyEQ(wechatOAuthProviderKey), + authidentity.ProviderSubjectEQ(strings.TrimSpace(providerSubject)), + ). + Only(ctx) + if err != nil && !dbent.IsNotFound(err) { + return infraerrors.InternalServer("WECHAT_BIND_LOOKUP_FAILED", "failed to inspect wechat identity ownership").WithCause(err) + } + if identity != nil && identity.UserID != userID { + return infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") + } + + channelSubject = strings.TrimSpace(channelSubject) + channelAppID := strings.TrimSpace(cfg.appID) + if channelSubject == "" || channelAppID == "" { + return nil + } + + channel, err := client.AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ("wechat"), + authidentitychannel.ProviderKeyEQ(wechatOAuthProviderKey), + authidentitychannel.ChannelEQ(strings.TrimSpace(cfg.mode)), + authidentitychannel.ChannelAppIDEQ(channelAppID), + authidentitychannel.ChannelSubjectEQ(channelSubject), + ). + WithIdentity(). + Only(ctx) + if err != nil && !dbent.IsNotFound(err) { + return infraerrors.InternalServer("WECHAT_BIND_CHANNEL_LOOKUP_FAILED", "failed to inspect wechat identity channel ownership").WithCause(err) + } + if channel != nil && channel.Edges.Identity != nil && channel.Edges.Identity.UserID != userID { + return infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user") + } + return nil +} + func (h *AuthHandler) getWeChatOAuthConfig(ctx context.Context, rawMode string, c *gin.Context) (wechatOAuthConfig, error) { mode, err := resolveWeChatOAuthMode(rawMode, c) if err != nil { diff --git a/backend/internal/handler/auth_wechat_oauth_test.go b/backend/internal/handler/auth_wechat_oauth_test.go index 1a765dcc..0d1df1b6 100644 --- a/backend/internal/handler/auth_wechat_oauth_test.go +++ b/backend/internal/handler/auth_wechat_oauth_test.go @@ -9,6 +9,7 @@ import ( "encoding/base64" "net/http" "net/http/httptest" + "net/url" "strings" "testing" "time" @@ -121,6 +122,298 @@ func TestWeChatOAuthCallbackCreatesPendingSessionForUnifiedFlow(t *testing.T) { require.Equal(t, "openid-123", session.UpstreamIdentityClaims["openid"]) } +func TestWeChatOAuthCallbackBindUsesUnionCanonicalIdentityAcrossChannels(t *testing.T) { + testCases := []struct { + name string + mode string + appIDEnv string + appID string + appSecret string + openID string + }{ + { + name: "open", + mode: "open", + appIDEnv: "WECHAT_OAUTH_OPEN_APP_ID", + appID: "wx-open-app", + appSecret: "wx-open-secret", + openID: "openid-open-123", + }, + { + name: "mp", + mode: "mp", + appIDEnv: "WECHAT_OAUTH_MP_APP_ID", + appID: "wx-mp-app", + appSecret: "wx-mp-secret", + openID: "openid-mp-123", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Setenv(tc.appIDEnv, tc.appID) + switch tc.mode { + case "open": + t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", tc.appSecret) + case "mp": + t.Setenv("WECHAT_OAUTH_MP_APP_SECRET", tc.appSecret) + } + t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback") + + originalAccessTokenURL := wechatOAuthAccessTokenURL + originalUserInfoURL := wechatOAuthUserInfoURL + t.Cleanup(func() { + wechatOAuthAccessTokenURL = originalAccessTokenURL + wechatOAuthUserInfoURL = originalUserInfoURL + }) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"` + tc.openID + `","unionid":"union-456","scope":"snsapi_login"}`)) + case strings.Contains(r.URL.Path, "/sns/userinfo"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"openid":"` + tc.openID + `","unionid":"union-456","nickname":"Bind Nick","headimgurl":"https://cdn.example/bind.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" + wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" + + handler, client := newWeChatOAuthTestHandler(t, false) + defer client.Close() + + currentUser, err := client.User.Create(). + SetEmail("current@example.com"). + SetUsername("current-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(context.Background()) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) + req.Host = "api.example.com" + req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) + req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) + req.AddCookie(encodedCookie(wechatOAuthIntentCookieName, wechatOAuthIntentBind)) + req.AddCookie(encodedCookie(wechatOAuthModeCookieName, tc.mode)) + req.AddCookie(encodedCookie(wechatOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret"))) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + c.Request = req + + handler.WeChatOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location")) + + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + session, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). + Only(context.Background()) + require.NoError(t, err) + require.Equal(t, wechatOAuthIntentBind, session.Intent) + require.NotNil(t, session.TargetUserID) + require.Equal(t, currentUser.ID, *session.TargetUserID) + require.Equal(t, currentUser.Email, session.ResolvedEmail) + require.Equal(t, "union-456", session.ProviderSubject) + require.Equal(t, "union-456", session.UpstreamIdentityClaims["subject"]) + require.Equal(t, "union-456", session.UpstreamIdentityClaims["unionid"]) + require.Equal(t, tc.openID, session.UpstreamIdentityClaims["openid"]) + require.Equal(t, tc.mode, session.UpstreamIdentityClaims["channel"]) + require.Equal(t, tc.appID, session.UpstreamIdentityClaims["channel_app_id"]) + require.Equal(t, tc.openID, session.UpstreamIdentityClaims["channel_subject"]) + + completionResponse := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + require.Equal(t, "/dashboard", completionResponse["redirect"]) + _, hasAccessToken := completionResponse["access_token"] + require.False(t, hasAccessToken) + }) + } +} + +func TestWeChatOAuthCallbackBindRejectsCanonicalOwnershipConflict(t *testing.T) { + t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") + t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") + t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback") + + originalAccessTokenURL := wechatOAuthAccessTokenURL + originalUserInfoURL := wechatOAuthUserInfoURL + t.Cleanup(func() { + wechatOAuthAccessTokenURL = originalAccessTokenURL + wechatOAuthUserInfoURL = originalUserInfoURL + }) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`)) + case strings.Contains(r.URL.Path, "/sns/userinfo"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Conflict Nick","headimgurl":"https://cdn.example/conflict.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" + wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" + + handler, client := newWeChatOAuthTestHandler(t, false) + defer client.Close() + + ctx := context.Background() + owner, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + currentUser, err := client.User.Create(). + SetEmail("current@example.com"). + SetUsername("current"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + _, err = client.AuthIdentity.Create(). + SetUserID(owner.ID). + SetProviderType("wechat"). + SetProviderKey(wechatOAuthProviderKey). + SetProviderSubject("union-456"). + SetMetadata(map[string]any{"unionid": "union-456"}). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) + req.Host = "api.example.com" + req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) + req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) + req.AddCookie(encodedCookie(wechatOAuthIntentCookieName, wechatOAuthIntentBind)) + req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) + req.AddCookie(encodedCookie(wechatOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret"))) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + c.Request = req + + handler.WeChatOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)) + assertOAuthRedirectError(t, recorder.Header().Get("Location"), "ownership_conflict", "AUTH_IDENTITY_OWNERSHIP_CONFLICT") + + count, err := client.PendingAuthSession.Query().Count(ctx) + require.NoError(t, err) + require.Zero(t, count) +} + +func TestWeChatOAuthCallbackBindRejectsChannelOwnershipConflict(t *testing.T) { + t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") + t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") + t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback") + + originalAccessTokenURL := wechatOAuthAccessTokenURL + originalUserInfoURL := wechatOAuthUserInfoURL + t.Cleanup(func() { + wechatOAuthAccessTokenURL = originalAccessTokenURL + wechatOAuthUserInfoURL = originalUserInfoURL + }) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`)) + case strings.Contains(r.URL.Path, "/sns/userinfo"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Conflict Nick","headimgurl":"https://cdn.example/conflict.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" + wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" + + handler, client := newWeChatOAuthTestHandler(t, false) + defer client.Close() + + ctx := context.Background() + owner, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + currentUser, err := client.User.Create(). + SetEmail("current@example.com"). + SetUsername("current"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + ownerIdentity, err := client.AuthIdentity.Create(). + SetUserID(owner.ID). + SetProviderType("wechat"). + SetProviderKey(wechatOAuthProviderKey). + SetProviderSubject("union-owner"). + SetMetadata(map[string]any{"unionid": "union-owner"}). + Save(ctx) + require.NoError(t, err) + + _, err = client.AuthIdentityChannel.Create(). + SetIdentityID(ownerIdentity.ID). + SetProviderType("wechat"). + SetProviderKey(wechatOAuthProviderKey). + SetChannel("open"). + SetChannelAppID("wx-open-app"). + SetChannelSubject("openid-123"). + SetMetadata(map[string]any{"openid": "openid-123"}). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) + req.Host = "api.example.com" + req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) + req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) + req.AddCookie(encodedCookie(wechatOAuthIntentCookieName, wechatOAuthIntentBind)) + req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) + req.AddCookie(encodedCookie(wechatOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret"))) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + c.Request = req + + handler.WeChatOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)) + assertOAuthRedirectError(t, recorder.Header().Get("Location"), "ownership_conflict", "AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT") + + count, err := client.PendingAuthSession.Query().Count(ctx) + require.NoError(t, err) + require.Zero(t, count) +} + func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing.T) { t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") @@ -322,6 +615,18 @@ func decodeCookieValueForTest(t *testing.T, value string) string { return string(raw) } +func assertOAuthRedirectError(t *testing.T, location string, errorCode string, errorMessage string) { + t.Helper() + + parsed, err := url.Parse(location) + require.NoError(t, err) + + fragment, err := url.ParseQuery(parsed.Fragment) + require.NoError(t, err) + require.Equal(t, errorCode, fragment.Get("error")) + require.Equal(t, errorMessage, fragment.Get("error_message")) +} + type wechatOAuthSettingRepoStub struct { values map[string]string } diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go index 904341d0..9dcff828 100644 --- a/backend/internal/handler/user_handler.go +++ b/backend/internal/handler/user_handler.go @@ -1,6 +1,8 @@ package handler import ( + "context" + "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/response" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" @@ -41,7 +43,8 @@ type UpdateProfileRequest struct { type userProfileResponse struct { dto.User - AvatarURL string `json:"avatar_url,omitempty"` + AvatarURL string `json:"avatar_url,omitempty"` + Identities service.UserIdentitySummarySet `json:"identities"` } // GetProfile handles getting user profile @@ -59,7 +62,13 @@ func (h *UserHandler) GetProfile(c *gin.Context) { return } - response.Success(c, userProfileResponseFromService(userData)) + profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, userData) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, profileResp) } // ChangePassword handles changing user password @@ -117,7 +126,44 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) { return } - response.Success(c, userProfileResponseFromService(updatedUser)) + profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, updatedUser) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, profileResp) +} + +type StartIdentityBindingRequest struct { + Provider string `json:"provider" binding:"required"` + RedirectTo string `json:"redirect_to"` +} + +// StartIdentityBinding returns the backend authorize URL for starting a third-party identity bind flow. +// POST /api/v1/user/auth-identities/bind/start +func (h *UserHandler) StartIdentityBinding(c *gin.Context) { + if _, ok := middleware2.GetAuthSubjectFromContext(c); !ok { + response.Unauthorized(c, "User not authenticated") + return + } + + var req StartIdentityBindingRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + result, err := h.userService.PrepareIdentityBindingStart(c.Request.Context(), service.StartUserIdentityBindingRequest{ + Provider: req.Provider, + RedirectTo: req.RedirectTo, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, result) } // SendNotifyEmailCodeRequest represents the request to send notify email verification code @@ -183,7 +229,13 @@ func (h *UserHandler) VerifyNotifyEmail(c *gin.Context) { return } - response.Success(c, userProfileResponseFromService(updatedUser)) + profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, updatedUser) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, profileResp) } // RemoveNotifyEmailRequest represents the request to remove a notify email @@ -219,7 +271,13 @@ func (h *UserHandler) RemoveNotifyEmail(c *gin.Context) { return } - response.Success(c, userProfileResponseFromService(updatedUser)) + profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, updatedUser) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, profileResp) } // ToggleNotifyEmailRequest represents the request to toggle a notify email's disabled state @@ -255,16 +313,31 @@ func (h *UserHandler) ToggleNotifyEmail(c *gin.Context) { return } - response.Success(c, userProfileResponseFromService(updatedUser)) + profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, updatedUser) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, profileResp) } -func userProfileResponseFromService(user *service.User) userProfileResponse { +func (h *UserHandler) buildUserProfileResponse(ctx context.Context, userID int64, user *service.User) (userProfileResponse, error) { + identities, err := h.userService.GetProfileIdentitySummaries(ctx, userID, user) + if err != nil { + return userProfileResponse{}, err + } + return userProfileResponseFromService(user, identities), nil +} + +func userProfileResponseFromService(user *service.User, identities service.UserIdentitySummarySet) userProfileResponse { base := dto.UserFromService(user) if base == nil { return userProfileResponse{} } return userProfileResponse{ - User: *base, - AvatarURL: user.AvatarURL, + User: *base, + AvatarURL: user.AvatarURL, + Identities: identities, } } diff --git a/backend/internal/handler/user_handler_test.go b/backend/internal/handler/user_handler_test.go index 1973f59e..b71846c1 100644 --- a/backend/internal/handler/user_handler_test.go +++ b/backend/internal/handler/user_handler_test.go @@ -9,6 +9,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" @@ -18,7 +19,8 @@ import ( ) type userHandlerRepoStub struct { - user *service.User + user *service.User + identities []service.UserAuthIdentityRecord } func (s *userHandlerRepoStub) Create(context.Context, *service.User) error { return nil } @@ -96,6 +98,11 @@ func (s *userHandlerRepoStub) RemoveGroupFromUserAllowedGroups(context.Context, func (s *userHandlerRepoStub) UpdateTotpSecret(context.Context, int64, *string) error { return nil } func (s *userHandlerRepoStub) EnableTotp(context.Context, int64) error { return nil } func (s *userHandlerRepoStub) DisableTotp(context.Context, int64) error { return nil } +func (s *userHandlerRepoStub) ListUserAuthIdentities(context.Context, int64) ([]service.UserAuthIdentityRecord, error) { + out := make([]service.UserAuthIdentityRecord, len(s.identities)) + copy(out, s.identities) + return out, nil +} func TestUserHandlerUpdateProfileReturnsAvatarURL(t *testing.T) { gin.SetMode(gin.TestMode) @@ -134,3 +141,135 @@ func TestUserHandlerUpdateProfileReturnsAvatarURL(t *testing.T) { require.Equal(t, "https://cdn.example.com/avatar.png", resp.Data.AvatarURL) require.Equal(t, "handler-avatar", resp.Data.Username) } + +func TestUserHandlerGetProfileReturnsIdentitySummaries(t *testing.T) { + gin.SetMode(gin.TestMode) + + verifiedAt := time.Date(2026, 4, 20, 8, 30, 0, 0, time.UTC) + repo := &userHandlerRepoStub{ + user: &service.User{ + ID: 11, + Email: "identity@example.com", + Username: "identity-user", + Role: service.RoleUser, + Status: service.StatusActive, + }, + identities: []service.UserAuthIdentityRecord{ + { + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: "linuxdo-subject-123456", + VerifiedAt: &verifiedAt, + Metadata: map[string]any{ + "username": "linuxdo-handle", + }, + }, + { + ProviderType: "oidc", + ProviderKey: "https://issuer.example.com", + ProviderSubject: "oidc-user-abc", + Metadata: map[string]any{ + "suggested_display_name": "OIDC Display", + }, + }, + }, + } + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/user/profile", nil) + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 11}) + + handler.GetProfile(c) + + require.Equal(t, http.StatusOK, recorder.Code) + + var resp struct { + Code int `json:"code"` + Data struct { + Identities struct { + Email struct { + Bound bool `json:"bound"` + BoundCount int `json:"bound_count"` + DisplayName string `json:"display_name"` + } `json:"email"` + LinuxDo struct { + Bound bool `json:"bound"` + BoundCount int `json:"bound_count"` + DisplayName string `json:"display_name"` + ProviderKey string `json:"provider_key"` + } `json:"linuxdo"` + OIDC struct { + Bound bool `json:"bound"` + DisplayName string `json:"display_name"` + ProviderKey string `json:"provider_key"` + } `json:"oidc"` + WeChat struct { + Bound bool `json:"bound"` + CanBind bool `json:"can_bind"` + BindStartPath string `json:"bind_start_path"` + } `json:"wechat"` + } `json:"identities"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.True(t, resp.Data.Identities.Email.Bound) + require.Equal(t, 1, resp.Data.Identities.Email.BoundCount) + require.Equal(t, "identity@example.com", resp.Data.Identities.Email.DisplayName) + require.True(t, resp.Data.Identities.LinuxDo.Bound) + require.Equal(t, 1, resp.Data.Identities.LinuxDo.BoundCount) + require.Equal(t, "linuxdo-handle", resp.Data.Identities.LinuxDo.DisplayName) + require.Equal(t, "linuxdo", resp.Data.Identities.LinuxDo.ProviderKey) + require.True(t, resp.Data.Identities.OIDC.Bound) + require.Equal(t, "OIDC Display", resp.Data.Identities.OIDC.DisplayName) + require.Equal(t, "https://issuer.example.com", resp.Data.Identities.OIDC.ProviderKey) + require.False(t, resp.Data.Identities.WeChat.Bound) + require.True(t, resp.Data.Identities.WeChat.CanBind) + require.Contains(t, resp.Data.Identities.WeChat.BindStartPath, "/api/v1/auth/oauth/wechat/start") +} + +func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) { + gin.SetMode(gin.TestMode) + + repo := &userHandlerRepoStub{ + user: &service.User{ + ID: 11, + Email: "identity@example.com", + Username: "identity-user", + Role: service.RoleUser, + Status: service.StatusActive, + }, + } + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil) + + body := []byte(`{"provider":"wechat","redirect_to":"/settings/profile"}`) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/user/auth-identities/bind/start", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 11}) + + handler.StartIdentityBinding(c) + + require.Equal(t, http.StatusOK, recorder.Code) + + var resp struct { + Code int `json:"code"` + Data struct { + Provider string `json:"provider"` + AuthorizeURL string `json:"authorize_url"` + Method string `json:"method"` + UseBrowserRedirect bool `json:"use_browser_redirect"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.Equal(t, "wechat", resp.Data.Provider) + require.Equal(t, "GET", resp.Data.Method) + require.True(t, resp.Data.UseBrowserRedirect) + require.Contains(t, resp.Data.AuthorizeURL, "/api/v1/auth/oauth/wechat/start") + require.Contains(t, resp.Data.AuthorizeURL, "intent=bind_current_user") + require.Contains(t, resp.Data.AuthorizeURL, "redirect=%2Fsettings%2Fprofile") +} diff --git a/backend/internal/repository/user_profile_identity_repo.go b/backend/internal/repository/user_profile_identity_repo.go index 4ecae4a4..c1b4b6bf 100644 --- a/backend/internal/repository/user_profile_identity_repo.go +++ b/backend/internal/repository/user_profile_identity_repo.go @@ -211,6 +211,34 @@ func (r *userRepository) GetUserByChannelIdentity(ctx context.Context, key AuthI }, nil } +func (r *userRepository) ListUserAuthIdentities(ctx context.Context, userID int64) ([]service.UserAuthIdentityRecord, error) { + identities, err := clientFromContext(ctx, r.client).AuthIdentity.Query(). + Where(authidentity.UserIDEQ(userID)). + All(ctx) + if err != nil { + return nil, err + } + + records := make([]service.UserAuthIdentityRecord, 0, len(identities)) + for _, identity := range identities { + if identity == nil { + continue + } + records = append(records, service.UserAuthIdentityRecord{ + ProviderType: strings.TrimSpace(identity.ProviderType), + ProviderKey: strings.TrimSpace(identity.ProviderKey), + ProviderSubject: strings.TrimSpace(identity.ProviderSubject), + VerifiedAt: identity.VerifiedAt, + Issuer: identity.Issuer, + Metadata: copyMetadata(identity.Metadata), + CreatedAt: identity.CreatedAt, + UpdatedAt: identity.UpdatedAt, + }) + } + + return records, nil +} + func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindAuthIdentityInput) (*CreateAuthIdentityResult, error) { var result *CreateAuthIdentityResult err := r.WithUserProfileIdentityTx(ctx, func(txCtx context.Context) error { diff --git a/backend/internal/server/routes/auth.go b/backend/internal/server/routes/auth.go index 911a4064..7a34834d 100644 --- a/backend/internal/server/routes/auth.go +++ b/backend/internal/server/routes/auth.go @@ -108,5 +108,23 @@ func RegisterAuthRoutes( authenticated.GET("/auth/me", h.Auth.GetCurrentUser) // 撤销所有会话(需要认证) authenticated.POST("/auth/revoke-all-sessions", h.Auth.RevokeAllSessions) + authenticated.GET("/auth/oauth/linuxdo/bind/start", func(c *gin.Context) { + query := c.Request.URL.Query() + query.Set("intent", "bind_current_user") + c.Request.URL.RawQuery = query.Encode() + h.Auth.LinuxDoOAuthStart(c) + }) + authenticated.GET("/auth/oauth/oidc/bind/start", func(c *gin.Context) { + query := c.Request.URL.Query() + query.Set("intent", "bind_current_user") + c.Request.URL.RawQuery = query.Encode() + h.Auth.OIDCOAuthStart(c) + }) + authenticated.GET("/auth/oauth/wechat/bind/start", func(c *gin.Context) { + query := c.Request.URL.Query() + query.Set("intent", "bind_current_user") + c.Request.URL.RawQuery = query.Encode() + h.Auth.WeChatOAuthStart(c) + }) } } diff --git a/backend/internal/server/routes/user.go b/backend/internal/server/routes/user.go index d004f8b4..ccbe23ce 100644 --- a/backend/internal/server/routes/user.go +++ b/backend/internal/server/routes/user.go @@ -25,6 +25,7 @@ func RegisterUserRoutes( user.GET("/profile", h.User.GetProfile) user.PUT("/password", h.User.ChangePassword) user.PUT("", h.User.UpdateProfile) + user.POST("/auth-identities/bind/start", h.User.StartIdentityBinding) // 通知邮箱管理 notifyEmail := user.Group("/notify-email") diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index 2f6d9427..c52f91bb 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -7,13 +7,13 @@ import ( "encoding/base64" "encoding/hex" "fmt" - "log/slog" - "net/url" - "strings" - "time" - infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "log/slog" + "net/url" + "sort" + "strings" + "time" ) var ( @@ -24,6 +24,8 @@ var ( ErrAvatarInvalid = infraerrors.BadRequest("AVATAR_INVALID", "avatar must be a valid image data URL or http(s) URL") ErrAvatarTooLarge = infraerrors.BadRequest("AVATAR_TOO_LARGE", "avatar image must be 100KB or smaller") ErrAvatarNotImage = infraerrors.BadRequest("AVATAR_NOT_IMAGE", "avatar content must be an image") + ErrIdentityProviderInvalid = infraerrors.BadRequest("IDENTITY_PROVIDER_INVALID", "identity provider is invalid") + ErrIdentityRedirectInvalid = infraerrors.BadRequest("IDENTITY_REDIRECT_INVALID", "identity redirect path is invalid") ) const ( @@ -33,6 +35,8 @@ const ( // User-level rate limiting for notify email verification codes notifyCodeUserRateLimit = 5 notifyCodeUserRateWindow = 10 * time.Minute + + defaultUserIdentityRedirect = "/settings/profile" ) // UserListFilters contains all filter options for listing users @@ -71,6 +75,7 @@ type UserRepository interface { AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error // RemoveGroupFromUserAllowedGroups 移除单个用户的指定分组权限 RemoveGroupFromUserAllowedGroups(ctx context.Context, userID int64, groupID int64) error + ListUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error) // TOTP 双因素认证 UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error @@ -78,6 +83,50 @@ type UserRepository interface { DisableTotp(ctx context.Context, userID int64) error } +type UserAuthIdentityRecord struct { + ProviderType string + ProviderKey string + ProviderSubject string + VerifiedAt *time.Time + Issuer *string + Metadata map[string]any + CreatedAt time.Time + UpdatedAt time.Time +} + +type UserIdentitySummary struct { + Provider string `json:"provider"` + Bound bool `json:"bound"` + BoundCount int `json:"bound_count"` + DisplayName string `json:"display_name,omitempty"` + SubjectHint string `json:"subject_hint,omitempty"` + ProviderKey string `json:"provider_key,omitempty"` + VerifiedAt *time.Time `json:"verified_at,omitempty"` + BindStartPath string `json:"bind_start_path,omitempty"` + CanBind bool `json:"can_bind"` + CanUnbind bool `json:"can_unbind"` + Note string `json:"note,omitempty"` +} + +type UserIdentitySummarySet struct { + Email UserIdentitySummary `json:"email"` + LinuxDo UserIdentitySummary `json:"linuxdo"` + OIDC UserIdentitySummary `json:"oidc"` + WeChat UserIdentitySummary `json:"wechat"` +} + +type StartUserIdentityBindingRequest struct { + Provider string + RedirectTo string +} + +type StartUserIdentityBindingResult struct { + Provider string `json:"provider"` + AuthorizeURL string `json:"authorize_url"` + Method string `json:"method"` + UseBrowserRedirect bool `json:"use_browser_redirect"` +} + // UpdateProfileRequest 更新用户资料请求 type UpdateProfileRequest struct { Email *string `json:"email"` @@ -106,6 +155,10 @@ type UpsertUserAvatarInput struct { SHA256 string } +type userAuthIdentityReader interface { + ListUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error) +} + // ChangePasswordRequest 修改密码请求 type ChangePasswordRequest struct { CurrentPassword string `json:"current_password"` @@ -151,6 +204,47 @@ func (s *UserService) GetProfile(ctx context.Context, userID int64) (*User, erro return user, nil } +func (s *UserService) GetProfileIdentitySummaries(ctx context.Context, userID int64, user *User) (UserIdentitySummarySet, error) { + if user == nil { + var err error + user, err = s.userRepo.GetByID(ctx, userID) + if err != nil { + return UserIdentitySummarySet{}, fmt.Errorf("get user: %w", err) + } + } + + records, err := s.listUserAuthIdentities(ctx, userID) + if err != nil { + return UserIdentitySummarySet{}, err + } + + return UserIdentitySummarySet{ + Email: s.buildEmailIdentitySummary(user), + LinuxDo: s.buildProviderIdentitySummary("linuxdo", records), + OIDC: s.buildProviderIdentitySummary("oidc", records), + WeChat: s.buildProviderIdentitySummary("wechat", records), + }, nil +} + +func (s *UserService) PrepareIdentityBindingStart(_ context.Context, req StartUserIdentityBindingRequest) (*StartUserIdentityBindingResult, error) { + provider := normalizeUserIdentityProvider(req.Provider) + if provider == "" { + return nil, ErrIdentityProviderInvalid + } + + authorizeURL, err := buildUserIdentityBindAuthorizeURL(provider, req.RedirectTo) + if err != nil { + return nil, err + } + + return &StartUserIdentityBindingResult{ + Provider: provider, + AuthorizeURL: authorizeURL, + Method: "GET", + UseBrowserRedirect: true, + }, nil +} + // UpdateProfile 更新用户资料 func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req UpdateProfileRequest) (*User, error) { user, err := s.userRepo.GetByID(ctx, userID) @@ -303,6 +397,234 @@ func normalizeInlineUserAvatarInput(raw string) (UpsertUserAvatarInput, error) { }, nil } +func (s *UserService) buildEmailIdentitySummary(user *User) UserIdentitySummary { + summary := UserIdentitySummary{ + Provider: "email", + CanBind: false, + CanUnbind: false, + Note: "Primary account email is managed from the profile form.", + } + if user == nil { + return summary + } + + email := strings.TrimSpace(user.Email) + if email == "" || isReservedEmail(email) { + return summary + } + + summary.Bound = true + summary.BoundCount = 1 + summary.DisplayName = email + summary.SubjectHint = maskEmailIdentity(email) + summary.ProviderKey = "email" + return summary +} + +func (s *UserService) buildProviderIdentitySummary(provider string, records []UserAuthIdentityRecord) UserIdentitySummary { + summary := UserIdentitySummary{ + Provider: provider, + CanUnbind: false, + } + filtered := filterUserAuthIdentities(records, provider) + if len(filtered) == 0 { + summary.CanBind = true + bindStartPath, err := buildUserIdentityBindAuthorizeURL(provider, "") + if err == nil { + summary.BindStartPath = bindStartPath + } + return summary + } + + primary := selectPrimaryUserAuthIdentity(filtered) + summary.Bound = true + summary.BoundCount = len(filtered) + summary.DisplayName = userAuthIdentityDisplayName(primary) + summary.SubjectHint = maskOpaqueIdentity(primary.ProviderSubject) + summary.ProviderKey = strings.TrimSpace(primary.ProviderKey) + summary.VerifiedAt = primary.VerifiedAt + summary.Note = "Unbind is not available yet." + return summary +} + +func (s *UserService) listUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error) { + if userID <= 0 || s == nil || s.userRepo == nil { + return nil, nil + } + return s.userRepo.ListUserAuthIdentities(ctx, userID) +} + +func buildUserIdentityBindAuthorizeURL(provider, redirectTo string) (string, error) { + provider = normalizeUserIdentityProvider(provider) + if provider == "" || provider == "email" { + return "", ErrIdentityProviderInvalid + } + + redirectTo, err := normalizeUserIdentityRedirect(redirectTo) + if err != nil { + return "", err + } + + path := "" + switch provider { + case "linuxdo": + path = "/api/v1/auth/oauth/linuxdo/start" + case "oidc": + path = "/api/v1/auth/oauth/oidc/start" + case "wechat": + path = "/api/v1/auth/oauth/wechat/start" + default: + return "", ErrIdentityProviderInvalid + } + + query := url.Values{} + query.Set("redirect", redirectTo) + query.Set("intent", "bind_current_user") + return path + "?" + query.Encode(), nil +} + +func normalizeUserIdentityProvider(provider string) string { + switch strings.ToLower(strings.TrimSpace(provider)) { + case "linuxdo": + return "linuxdo" + case "oidc": + return "oidc" + case "wechat": + return "wechat" + case "email": + return "email" + default: + return "" + } +} + +func normalizeUserIdentityRedirect(raw string) (string, error) { + redirect := strings.TrimSpace(raw) + if redirect == "" { + return defaultUserIdentityRedirect, nil + } + if len(redirect) > 2048 || !strings.HasPrefix(redirect, "/") || strings.HasPrefix(redirect, "//") { + return "", ErrIdentityRedirectInvalid + } + return redirect, nil +} + +func filterUserAuthIdentities(records []UserAuthIdentityRecord, provider string) []UserAuthIdentityRecord { + if len(records) == 0 { + return nil + } + filtered := make([]UserAuthIdentityRecord, 0, len(records)) + for _, record := range records { + if strings.EqualFold(strings.TrimSpace(record.ProviderType), provider) { + filtered = append(filtered, record) + } + } + return filtered +} + +func selectPrimaryUserAuthIdentity(records []UserAuthIdentityRecord) UserAuthIdentityRecord { + if len(records) == 0 { + return UserAuthIdentityRecord{} + } + sort.SliceStable(records, func(i, j int) bool { + left := userAuthIdentitySortTime(records[i]) + right := userAuthIdentitySortTime(records[j]) + if !left.Equal(right) { + return left.After(right) + } + return records[i].ProviderKey < records[j].ProviderKey + }) + return records[0] +} + +func userAuthIdentitySortTime(record UserAuthIdentityRecord) time.Time { + if record.VerifiedAt != nil && !record.VerifiedAt.IsZero() { + return record.VerifiedAt.UTC() + } + if !record.UpdatedAt.IsZero() { + return record.UpdatedAt.UTC() + } + if !record.CreatedAt.IsZero() { + return record.CreatedAt.UTC() + } + return time.Time{} +} + +func userAuthIdentityDisplayName(record UserAuthIdentityRecord) string { + if displayName := firstStringIdentityValue(record.Metadata, + "display_name", + "suggested_display_name", + "username", + "name", + "nickname", + "email", + ); displayName != "" { + return displayName + } + if subject := strings.TrimSpace(record.ProviderSubject); subject != "" { + return subject + } + return strings.TrimSpace(record.ProviderType) +} + +func firstStringIdentityValue(values map[string]any, keys ...string) string { + for _, key := range keys { + raw, ok := values[key] + if !ok { + continue + } + switch value := raw.(type) { + case string: + if trimmed := strings.TrimSpace(value); trimmed != "" { + return trimmed + } + case fmt.Stringer: + if trimmed := strings.TrimSpace(value.String()); trimmed != "" { + return trimmed + } + } + } + return "" +} + +func maskEmailIdentity(email string) string { + local, domain, ok := strings.Cut(strings.TrimSpace(email), "@") + if !ok || local == "" || domain == "" { + return maskOpaqueIdentity(email) + } + runes := []rune(local) + if len(runes) == 1 { + return string(runes[0]) + "***@" + domain + } + return string(runes[0]) + "***" + string(runes[len(runes)-1]) + "@" + domain +} + +func maskOpaqueIdentity(value string) string { + value = strings.TrimSpace(value) + runes := []rune(value) + switch { + case len(runes) == 0: + return "" + case len(runes) <= 4: + return string(runes[0]) + "***" + case len(runes) <= 8: + return string(runes[:2]) + "***" + string(runes[len(runes)-1:]) + default: + return string(runes[:3]) + "***" + string(runes[len(runes)-3:]) + } +} + +func cloneAnyMap(values map[string]any) map[string]any { + if len(values) == 0 { + return map[string]any{} + } + cloned := make(map[string]any, len(values)) + for key, value := range values { + cloned[key] = value + } + return cloned +} + // ChangePassword 修改密码 // Security: Increments TokenVersion to invalidate all existing JWT tokens func (s *UserService) ChangePassword(ctx context.Context, userID int64, req ChangePasswordRequest) error { diff --git a/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts b/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts index 574e1e36..9c0b4d55 100644 --- a/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts +++ b/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts @@ -12,6 +12,8 @@ describe('oauth adoption auth api', () => { beforeEach(() => { post.mockReset() post.mockResolvedValue({ data: {} }) + localStorage.clear() + document.cookie = 'oauth_bind_access_token=; Max-Age=0; path=/' }) it('posts adoption decisions when exchanging pending oauth completion', async () => { @@ -57,4 +59,43 @@ describe('oauth adoption auth api', () => { adopt_avatar: true }) }) + + it('posts wechat invitation completion with adoption decisions', async () => { + const { completeWeChatOAuthRegistration } = await import('@/api/auth') + + await completeWeChatOAuthRegistration('invite-code', { + adoptDisplayName: true, + adoptAvatar: true + }) + + expect(post).toHaveBeenCalledWith('/auth/oauth/wechat/complete-registration', { + invitation_code: 'invite-code', + adopt_display_name: true, + adopt_avatar: true + }) + }) + + it('classifies oauth completion results as login or bind', async () => { + const { getOAuthCompletionKind } = await import('@/api/auth') + + expect(getOAuthCompletionKind({ access_token: 'access-token' })).toBe('login') + expect(getOAuthCompletionKind({ redirect: '/profile' })).toBe('bind') + }) + + it('prepares an oauth bind access token cookie before redirect binding', async () => { + localStorage.setItem('auth_token', 'access-token-value') + const setCookie = vi.fn() + Object.defineProperty(document, 'cookie', { + configurable: true, + get: () => '', + set: setCookie + }) + + const { prepareOAuthBindAccessTokenCookie } = await import('@/api/auth') + + prepareOAuthBindAccessTokenCookie() + + expect(setCookie).toHaveBeenCalledTimes(1) + expect(setCookie.mock.calls[0]?.[0]).toContain('oauth_bind_access_token=access-token-value') + }) }) diff --git a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts index 10b6ca58..c11bd90b 100644 --- a/frontend/src/api/auth.ts +++ b/frontend/src/api/auth.ts @@ -186,11 +186,14 @@ export interface RefreshTokenResponse { token_type: string } -export interface PendingOAuthExchangeResponse { - access_token?: string +export interface OAuthTokenResponse { + access_token: string refresh_token?: string expires_in?: number token_type?: string +} + +export interface PendingOAuthExchangeResponse extends Partial { redirect?: string error?: string adoption_required?: boolean @@ -198,6 +201,8 @@ export interface PendingOAuthExchangeResponse { suggested_avatar_url?: string } +export type OAuthCompletionKind = 'login' | 'bind' + export interface OAuthAdoptionDecision { adoptDisplayName?: boolean adoptAvatar?: boolean @@ -218,6 +223,56 @@ function serializeOAuthAdoptionDecision( return payload } +export function isOAuthLoginCompletion( + completion: Partial +): completion is OAuthTokenResponse { + return typeof completion.access_token === 'string' && completion.access_token.trim().length > 0 +} + +export function getOAuthCompletionKind( + completion: Partial +): OAuthCompletionKind { + return isOAuthLoginCompletion(completion) ? 'login' : 'bind' +} + +export function persistOAuthTokenContext(tokens: Partial): void { + if (tokens.refresh_token) { + setRefreshToken(tokens.refresh_token) + } + if (tokens.expires_in) { + setTokenExpiresAt(tokens.expires_in) + } +} + +export function prepareOAuthBindAccessTokenCookie(): void { + if (typeof document === 'undefined' || typeof window === 'undefined') { + return + } + + const token = getAuthToken() + if (!token) { + return + } + + const secure = window.location.protocol === 'https:' ? '; Secure' : '' + const path = resolveOAuthBindCookiePath() + document.cookie = + `oauth_bind_access_token=${encodeURIComponent(token)}; Path=${path}/auth/oauth; Max-Age=600; SameSite=Lax${secure}` +} + +function resolveOAuthBindCookiePath(): string { + const apiBase = ((import.meta.env.VITE_API_BASE_URL as string | undefined) || '/api/v1').replace(/\/$/, '') + + try { + return new URL(apiBase, window.location.origin).pathname.replace(/\/$/, '') || '/api/v1' + } catch { + if (apiBase.startsWith('/')) { + return apiBase + } + return '/api/v1' + } +} + /** * Refresh the access token using the refresh token * @returns New token pair @@ -375,13 +430,8 @@ export async function resetPassword(request: ResetPasswordRequest): Promise { - const { data } = await apiClient.post<{ - access_token: string - refresh_token: string - expires_in: number - token_type: string - }>('/auth/oauth/linuxdo/complete-registration', { +): Promise { + const { data } = await apiClient.post('/auth/oauth/linuxdo/complete-registration', { invitation_code: invitationCode, ...serializeOAuthAdoptionDecision(decision) }) @@ -396,13 +446,19 @@ export async function completeLinuxDoOAuthRegistration( export async function completeOIDCOAuthRegistration( invitationCode: string, decision?: OAuthAdoptionDecision -): Promise<{ access_token: string; refresh_token: string; expires_in: number; token_type: string }> { - const { data } = await apiClient.post<{ - access_token: string - refresh_token: string - expires_in: number - token_type: string - }>('/auth/oauth/oidc/complete-registration', { +): Promise { + const { data } = await apiClient.post('/auth/oauth/oidc/complete-registration', { + invitation_code: invitationCode, + ...serializeOAuthAdoptionDecision(decision) + }) + return data +} + +export async function completeWeChatOAuthRegistration( + invitationCode: string, + decision?: OAuthAdoptionDecision +): Promise { + const { data } = await apiClient.post('/auth/oauth/wechat/complete-registration', { invitation_code: invitationCode, ...serializeOAuthAdoptionDecision(decision) }) @@ -444,7 +500,8 @@ export const authAPI = { revokeAllSessions, exchangePendingOAuthCompletion, completeLinuxDoOAuthRegistration, - completeOIDCOAuthRegistration + completeOIDCOAuthRegistration, + completeWeChatOAuthRegistration } export default authAPI diff --git a/frontend/src/api/user.ts b/frontend/src/api/user.ts index cd648270..1f6e4cd9 100644 --- a/frontend/src/api/user.ts +++ b/frontend/src/api/user.ts @@ -4,7 +4,8 @@ */ import { apiClient } from './client' -import type { User, ChangePasswordRequest, NotifyEmailEntry } from '@/types' +import { prepareOAuthBindAccessTokenCookie } from './auth' +import type { User, ChangePasswordRequest, NotifyEmailEntry, UserAuthProvider } from '@/types' /** * Get current user profile @@ -83,6 +84,49 @@ export async function toggleNotifyEmail(email: string, disabled: boolean): Promi return data } +export type BindableOAuthProvider = Exclude + +interface BuildOAuthBindingStartURLOptions { + redirectTo?: string +} + +export function resolveWeChatOAuthMode(): 'open' | 'mp' { + if (typeof navigator === 'undefined') { + return 'open' + } + return /MicroMessenger/i.test(navigator.userAgent) ? 'mp' : 'open' +} + +export function buildOAuthBindingStartURL( + provider: BindableOAuthProvider, + options: BuildOAuthBindingStartURLOptions = {} +): string { + const redirectTo = options.redirectTo?.trim() || '/profile' + const apiBase = (import.meta.env.VITE_API_BASE_URL as string | undefined) || '/api/v1' + const normalized = apiBase.replace(/\/$/, '') + const params = new URLSearchParams({ + redirect: redirectTo, + intent: 'bind_current_user' + }) + + if (provider === 'wechat') { + params.set('mode', resolveWeChatOAuthMode()) + } + + return `${normalized}/auth/oauth/${provider}/start?${params.toString()}` +} + +export function startOAuthBinding( + provider: BindableOAuthProvider, + options: BuildOAuthBindingStartURLOptions = {} +): void { + if (typeof window === 'undefined') { + return + } + prepareOAuthBindAccessTokenCookie() + window.location.href = buildOAuthBindingStartURL(provider, options) +} + export const userAPI = { getProfile, updateProfile, @@ -90,7 +134,9 @@ export const userAPI = { sendNotifyEmailCode, verifyNotifyEmail, removeNotifyEmail, - toggleNotifyEmail + toggleNotifyEmail, + buildOAuthBindingStartURL, + startOAuthBinding } export default userAPI diff --git a/frontend/src/components/user/profile/ProfileIdentityBindingsSection.vue b/frontend/src/components/user/profile/ProfileIdentityBindingsSection.vue new file mode 100644 index 00000000..b767b2f3 --- /dev/null +++ b/frontend/src/components/user/profile/ProfileIdentityBindingsSection.vue @@ -0,0 +1,144 @@ + + + diff --git a/frontend/src/components/user/profile/ProfileInfoCard.vue b/frontend/src/components/user/profile/ProfileInfoCard.vue index b6f6022d..e82ae229 100644 --- a/frontend/src/components/user/profile/ProfileInfoCard.vue +++ b/frontend/src/components/user/profile/ProfileInfoCard.vue @@ -4,11 +4,16 @@ class="border-b border-gray-100 bg-gradient-to-r from-primary-500/10 to-primary-600/5 px-6 py-5 dark:border-dark-700 dark:from-primary-500/20 dark:to-primary-600/10" >
-
- {{ user?.email?.charAt(0).toUpperCase() || 'U' }} + + {{ avatarInitial }}

@@ -41,18 +46,163 @@ {{ user.username }}

+ +
+
+ + {{ hint.text }} +
+
+ + diff --git a/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts b/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts new file mode 100644 index 00000000..1c9531e3 --- /dev/null +++ b/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts @@ -0,0 +1,120 @@ +import { mount } from '@vue/test-utils' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import ProfileIdentityBindingsSection from '@/components/user/profile/ProfileIdentityBindingsSection.vue' +import type { User } from '@/types' + +const routeState = vi.hoisted(() => ({ + fullPath: '/profile', +})) + +const locationState = vi.hoisted(() => ({ + current: { href: 'http://localhost/profile' } as { href: string }, +})) + +vi.mock('vue-router', () => ({ + useRoute: () => routeState, +})) + +vi.mock('vue-i18n', async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + useI18n: () => ({ + t: (key: string, params?: Record) => { + if (key === 'profile.authBindings.title') return 'Connected sign-in methods' + if (key === 'profile.authBindings.description') return 'Manage bound providers' + if (key === 'profile.authBindings.status.bound') return 'Bound' + if (key === 'profile.authBindings.status.notBound') return 'Not bound' + if (key === 'profile.authBindings.providers.email') return 'Email' + if (key === 'profile.authBindings.providers.linuxdo') return 'LinuxDo' + if (key === 'profile.authBindings.providers.wechat') return 'WeChat' + if (key === 'profile.authBindings.providers.oidc') return params?.providerName || 'OIDC' + if (key === 'profile.authBindings.bindAction') return `Bind ${params?.providerName || ''}`.trim() + return key + }, + }), + } +}) + +function createUser(overrides: Partial = {}): User { + return { + id: 7, + username: 'alice', + email: 'alice@example.com', + role: 'user', + balance: 10, + concurrency: 2, + status: 'active', + allowed_groups: null, + balance_notify_enabled: true, + balance_notify_threshold: null, + balance_notify_extra_emails: [], + created_at: '2026-04-20T00:00:00Z', + updated_at: '2026-04-20T00:00:00Z', + ...overrides, + } +} + +describe('ProfileIdentityBindingsSection', () => { + beforeEach(() => { + routeState.fullPath = '/profile' + locationState.current = { href: 'http://localhost/profile' } + Object.defineProperty(window, 'location', { + configurable: true, + value: locationState.current, + }) + Object.defineProperty(window.navigator, 'userAgent', { + configurable: true, + value: 'Mozilla/5.0', + }) + }) + + afterEach(() => { + vi.unstubAllGlobals() + }) + + it('renders provider binding states and provider-specific bind actions', () => { + const wrapper = mount(ProfileIdentityBindingsSection, { + props: { + user: createUser({ + auth_bindings: { + email: { bound: true }, + linuxdo: { bound: true }, + oidc: { bound: false }, + wechat: false, + }, + }), + linuxdoEnabled: true, + oidcEnabled: true, + oidcProviderName: 'ExampleID', + wechatEnabled: true, + }, + }) + + expect(wrapper.get('[data-testid="profile-binding-email-status"]').text()).toBe('Bound') + expect(wrapper.get('[data-testid="profile-binding-linuxdo-status"]').text()).toBe('Bound') + expect(wrapper.get('[data-testid="profile-binding-oidc-status"]').text()).toBe('Not bound') + expect(wrapper.get('[data-testid="profile-binding-oidc-action"]').text()).toBe( + 'Bind ExampleID' + ) + expect(wrapper.get('[data-testid="profile-binding-wechat-action"]').text()).toBe('Bind WeChat') + }) + + it('starts the WeChat bind flow for the current profile page', async () => { + const wrapper = mount(ProfileIdentityBindingsSection, { + props: { + user: createUser(), + linuxdoEnabled: false, + oidcEnabled: false, + wechatEnabled: true, + }, + }) + + await wrapper.get('[data-testid="profile-binding-wechat-action"]').trigger('click') + + expect(locationState.current.href).toContain('/api/v1/auth/oauth/wechat/start?') + expect(locationState.current.href).toContain('mode=open') + expect(locationState.current.href).toContain('intent=bind_current_user') + expect(locationState.current.href).toContain('redirect=%2Fprofile') + }) +}) diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 684c196f..7d058a74 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -940,6 +940,26 @@ export default { maxEmailsReached: 'Maximum number of notification emails reached', unverified: 'Unverified', verified: 'Verified', + }, + authBindings: { + title: 'Connected Sign-In Methods', + description: 'View current bindings and connect another provider to this account.', + bindAction: 'Bind {providerName}', + bindSuccess: 'Account linked successfully', + status: { + bound: 'Bound', + notBound: 'Not bound', + }, + providers: { + email: 'Email', + linuxdo: 'LinuxDo', + oidc: '{providerName}', + wechat: 'WeChat', + }, + source: { + avatar: 'Avatar is currently synced from {providerName}', + username: 'Nickname is currently synced from {providerName}', + }, } }, diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 2a4c69a5..6dd74334 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -944,6 +944,26 @@ export default { maxEmailsReached: '已达到通知邮箱数量上限', unverified: '未验证', verified: '已验证', + }, + authBindings: { + title: '登录方式绑定', + description: '查看当前绑定状态,并将更多第三方登录方式关联到这个账号。', + bindAction: '绑定 {providerName}', + bindSuccess: '账号绑定成功', + status: { + bound: '已绑定', + notBound: '未绑定', + }, + providers: { + email: '邮箱', + linuxdo: 'LinuxDo', + oidc: '{providerName}', + wechat: '微信', + }, + source: { + avatar: '头像当前来自 {providerName}', + username: '昵称当前来自 {providerName}', + }, } }, diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 9c9722a9..a19d6c26 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -34,10 +34,47 @@ export interface NotifyEmailEntry { // ==================== User & Auth Types ==================== +export type UserAuthProvider = 'email' | 'linuxdo' | 'oidc' | 'wechat' + +export interface UserAuthBindingStatus { + bound?: boolean + provider?: UserAuthProvider | string + provider_key?: string | null + provider_subject?: string | null + issuer?: string | null + label?: string | null + provider_label?: string | null + metadata?: Record +} + +export interface UserProfileSourceContext { + provider?: UserAuthProvider | string + source?: string | null + label?: string | null + provider_label?: string | null +} + export interface User { id: number username: string email: string + avatar_url?: string | null + avatar_source?: string | UserProfileSourceContext | null + username_source?: string | UserProfileSourceContext | null + display_name_source?: string | UserProfileSourceContext | null + nickname_source?: string | UserProfileSourceContext | null + profile_sources?: { + avatar?: string | UserProfileSourceContext | null + username?: string | UserProfileSourceContext | null + display_name?: string | UserProfileSourceContext | null + nickname?: string | UserProfileSourceContext | null + } + auth_bindings?: Partial> + identity_bindings?: Partial> + email_bound?: boolean + linuxdo_bound?: boolean + oidc_bound?: boolean + wechat_bound?: boolean role: 'admin' | 'user' // User role for authorization balance: number // User balance for API usage concurrency: number // Allowed concurrent requests diff --git a/frontend/src/views/auth/LinuxDoCallbackView.vue b/frontend/src/views/auth/LinuxDoCallbackView.vue index 0a125def..6dc8f242 100644 --- a/frontend/src/views/auth/LinuxDoCallbackView.vue +++ b/frontend/src/views/auth/LinuxDoCallbackView.vue @@ -136,6 +136,9 @@ import { useAuthStore, useAppStore } from '@/stores' import { completeLinuxDoOAuthRegistration, exchangePendingOAuthCompletion, + getOAuthCompletionKind, + isOAuthLoginCompletion, + persistOAuthTokenContext, type OAuthAdoptionDecision, type PendingOAuthExchangeResponse } from '@/api/auth' @@ -162,6 +165,7 @@ const suggestedAvatarUrl = ref('') const adoptDisplayName = ref(true) const adoptAvatar = ref(true) const needsAdoptionConfirmation = ref(false) +const bindSuccessMessage = t('profile.authBindings.bindSuccess') function parseFragmentParams(): URLSearchParams { const raw = typeof window !== 'undefined' ? window.location.hash : '' @@ -209,18 +213,19 @@ function hasSuggestedProfile(completion: { return Boolean(completion.suggested_display_name || completion.suggested_avatar_url) } -async function finalizeLogin(completion: PendingOAuthExchangeResponse, redirect: string) { - if (!completion.access_token) { +async function finalizeCompletion(completion: PendingOAuthExchangeResponse, redirect: string) { + if (getOAuthCompletionKind(completion) === 'bind') { + const bindRedirect = sanitizeRedirectPath(completion.redirect || '/profile') + appStore.showSuccess(bindSuccessMessage) + await router.replace(bindRedirect) + return + } + + if (!isOAuthLoginCompletion(completion)) { throw new Error(t('auth.linuxdo.callbackMissingToken')) } - if (completion.refresh_token) { - localStorage.setItem('refresh_token', completion.refresh_token) - } - if (completion.expires_in) { - localStorage.setItem('token_expires_at', String(Date.now() + completion.expires_in * 1000)) - } - + persistOAuthTokenContext(completion) await authStore.setToken(completion.access_token) appStore.showSuccess(t('auth.loginSuccess')) await router.replace(redirect) @@ -236,12 +241,7 @@ async function handleSubmitInvitation() { invitationCode.value.trim(), currentAdoptionDecision() ) - if (tokenData.refresh_token) { - localStorage.setItem('refresh_token', tokenData.refresh_token) - } - if (tokenData.expires_in) { - localStorage.setItem('token_expires_at', String(Date.now() + tokenData.expires_in * 1000)) - } + persistOAuthTokenContext(tokenData) await authStore.setToken(tokenData.access_token) appStore.showSuccess(t('auth.loginSuccess')) await router.replace(redirectTo.value) @@ -258,7 +258,7 @@ async function handleContinueLogin() { isSubmitting.value = true try { const completion = await exchangePendingOAuthCompletion(currentAdoptionDecision()) - await finalizeLogin(completion, redirectTo.value) + await finalizeCompletion(completion, redirectTo.value) } catch (e: unknown) { const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } } errorMessage.value = @@ -305,7 +305,7 @@ onMounted(async () => { return } - await finalizeLogin(completion, redirect) + await finalizeCompletion(completion, redirect) } catch (e: unknown) { const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } } errorMessage.value = diff --git a/frontend/src/views/auth/OidcCallbackView.vue b/frontend/src/views/auth/OidcCallbackView.vue index 55f8af6e..83304226 100644 --- a/frontend/src/views/auth/OidcCallbackView.vue +++ b/frontend/src/views/auth/OidcCallbackView.vue @@ -145,7 +145,10 @@ import { useAuthStore, useAppStore } from '@/stores' import { completeOIDCOAuthRegistration, exchangePendingOAuthCompletion, + getOAuthCompletionKind, getPublicSettings, + isOAuthLoginCompletion, + persistOAuthTokenContext, type OAuthAdoptionDecision, type PendingOAuthExchangeResponse } from '@/api/auth' @@ -172,6 +175,7 @@ const suggestedAvatarUrl = ref('') const adoptDisplayName = ref(true) const adoptAvatar = ref(true) const needsAdoptionConfirmation = ref(false) +const bindSuccessMessage = t('profile.authBindings.bindSuccess') function parseFragmentParams(): URLSearchParams { const raw = typeof window !== 'undefined' ? window.location.hash : '' @@ -231,18 +235,19 @@ function hasSuggestedProfile(completion: { return Boolean(completion.suggested_display_name || completion.suggested_avatar_url) } -async function finalizeLogin(completion: PendingOAuthExchangeResponse, redirect: string) { - if (!completion.access_token) { +async function finalizeCompletion(completion: PendingOAuthExchangeResponse, redirect: string) { + if (getOAuthCompletionKind(completion) === 'bind') { + const bindRedirect = sanitizeRedirectPath(completion.redirect || '/profile') + appStore.showSuccess(bindSuccessMessage) + await router.replace(bindRedirect) + return + } + + if (!isOAuthLoginCompletion(completion)) { throw new Error(t('auth.oidc.callbackMissingToken')) } - if (completion.refresh_token) { - localStorage.setItem('refresh_token', completion.refresh_token) - } - if (completion.expires_in) { - localStorage.setItem('token_expires_at', String(Date.now() + completion.expires_in * 1000)) - } - + persistOAuthTokenContext(completion) await authStore.setToken(completion.access_token) appStore.showSuccess(t('auth.loginSuccess')) await router.replace(redirect) @@ -258,12 +263,7 @@ async function handleSubmitInvitation() { invitationCode.value.trim(), currentAdoptionDecision() ) - if (tokenData.refresh_token) { - localStorage.setItem('refresh_token', tokenData.refresh_token) - } - if (tokenData.expires_in) { - localStorage.setItem('token_expires_at', String(Date.now() + tokenData.expires_in * 1000)) - } + persistOAuthTokenContext(tokenData) await authStore.setToken(tokenData.access_token) appStore.showSuccess(t('auth.loginSuccess')) await router.replace(redirectTo.value) @@ -280,7 +280,7 @@ async function handleContinueLogin() { isSubmitting.value = true try { const completion = await exchangePendingOAuthCompletion(currentAdoptionDecision()) - await finalizeLogin(completion, redirectTo.value) + await finalizeCompletion(completion, redirectTo.value) } catch (e: unknown) { const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } } errorMessage.value = @@ -329,7 +329,7 @@ onMounted(async () => { return } - await finalizeLogin(completion, redirect) + await finalizeCompletion(completion, redirect) } catch (e: unknown) { const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } } errorMessage.value = diff --git a/frontend/src/views/auth/WechatCallbackView.vue b/frontend/src/views/auth/WechatCallbackView.vue index 407b395b..ac0ede4c 100644 --- a/frontend/src/views/auth/WechatCallbackView.vue +++ b/frontend/src/views/auth/WechatCallbackView.vue @@ -140,27 +140,16 @@ import { useRoute, useRouter } from 'vue-router' import { useI18n } from 'vue-i18n' import { AuthLayout } from '@/components/layout' import Icon from '@/components/icons/Icon.vue' -import { apiClient } from '@/api/client' import { useAuthStore, useAppStore } from '@/stores' - -interface OAuthTokenResponse { - access_token: string - refresh_token: string - expires_in: number - token_type: string -} - -interface PendingOAuthExchangeResponse { - access_token?: string - refresh_token?: string - expires_in?: number - token_type?: string - redirect?: string - error?: string - adoption_required?: boolean - suggested_display_name?: string - suggested_avatar_url?: string -} +import { + completeWeChatOAuthRegistration, + exchangePendingOAuthCompletion, + getOAuthCompletionKind, + isOAuthLoginCompletion, + persistOAuthTokenContext, + type OAuthAdoptionDecision, + type PendingOAuthExchangeResponse +} from '@/api/auth' const route = useRoute() const router = useRouter() @@ -182,6 +171,7 @@ const suggestedAvatarUrl = ref('') const adoptDisplayName = ref(true) const adoptAvatar = ref(true) const needsAdoptionConfirmation = ref(false) +const bindSuccessMessage = t('profile.authBindings.bindSuccess') const providerName = 'WeChat' @@ -200,10 +190,10 @@ function sanitizeRedirectPath(path: string | null | undefined): string { return path } -function currentAdoptionDecision(): Record { +function currentAdoptionDecision(): OAuthAdoptionDecision { return { - adopt_display_name: adoptDisplayName.value, - adopt_avatar: adoptAvatar.value, + adoptDisplayName: adoptDisplayName.value, + adoptAvatar: adoptAvatar.value } } @@ -224,49 +214,35 @@ function hasSuggestedProfile(completion: PendingOAuthExchangeResponse): boolean return Boolean(completion.suggested_display_name || completion.suggested_avatar_url) } -async function exchangePendingOAuthCompletion(): Promise { - const { data } = await apiClient.post('/auth/oauth/pending/exchange', {}) - return data -} +async function finalizeCompletion(completion: PendingOAuthExchangeResponse, redirect: string) { + if (getOAuthCompletionKind(completion) === 'bind') { + const bindRedirect = sanitizeRedirectPath(completion.redirect || '/profile') + appStore.showSuccess(bindSuccessMessage) + await router.replace(bindRedirect) + return + } -async function finalizeLogin(completion: PendingOAuthExchangeResponse, redirect: string) { - if (!completion.access_token) { + if (!isOAuthLoginCompletion(completion)) { throw new Error(t('auth.oidc.callbackMissingToken')) } - if (completion.refresh_token) { - localStorage.setItem('refresh_token', completion.refresh_token) - } - if (completion.expires_in) { - localStorage.setItem('token_expires_at', String(Date.now() + completion.expires_in * 1000)) - } - + persistOAuthTokenContext(completion) await authStore.setToken(completion.access_token) appStore.showSuccess(t('auth.loginSuccess')) await router.replace(redirect) } -async function completeWeChatOAuthRegistration(invitation: string): Promise { - const { data } = await apiClient.post('/auth/oauth/wechat/complete-registration', { - invitation_code: invitation, - ...currentAdoptionDecision(), - }) - return data -} - async function handleSubmitInvitation() { invitationError.value = '' if (!invitationCode.value.trim()) return isSubmitting.value = true try { - const tokenData = await completeWeChatOAuthRegistration(invitationCode.value.trim()) - if (tokenData.refresh_token) { - localStorage.setItem('refresh_token', tokenData.refresh_token) - } - if (tokenData.expires_in) { - localStorage.setItem('token_expires_at', String(Date.now() + tokenData.expires_in * 1000)) - } + const tokenData = await completeWeChatOAuthRegistration( + invitationCode.value.trim(), + currentAdoptionDecision() + ) + persistOAuthTokenContext(tokenData) await authStore.setToken(tokenData.access_token) appStore.showSuccess(t('auth.loginSuccess')) await router.replace(redirectTo.value) @@ -282,11 +258,8 @@ async function handleSubmitInvitation() { async function handleContinueLogin() { isSubmitting.value = true try { - const { data } = await apiClient.post( - '/auth/oauth/pending/exchange', - currentAdoptionDecision() - ) - await finalizeLogin(data, redirectTo.value) + const completion = await exchangePendingOAuthCompletion(currentAdoptionDecision()) + await finalizeCompletion(completion, redirectTo.value) } catch (e: unknown) { const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } } errorMessage.value = @@ -333,7 +306,7 @@ onMounted(async () => { return } - await finalizeLogin(completion, redirect) + await finalizeCompletion(completion, redirect) } catch (e: unknown) { const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } } errorMessage.value = diff --git a/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts b/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts index 60a40474..7ffdcd19 100644 --- a/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts +++ b/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts @@ -39,10 +39,14 @@ vi.mock('@/stores', () => ({ }) })) -vi.mock('@/api/auth', () => ({ - exchangePendingOAuthCompletion: (...args: any[]) => exchangePendingOAuthCompletion(...args), - completeLinuxDoOAuthRegistration: (...args: any[]) => completeLinuxDoOAuthRegistration(...args) -})) +vi.mock('@/api/auth', async () => { + const actual = await vi.importActual('@/api/auth') + return { + ...actual, + exchangePendingOAuthCompletion: (...args: any[]) => exchangePendingOAuthCompletion(...args), + completeLinuxDoOAuthRegistration: (...args: any[]) => completeLinuxDoOAuthRegistration(...args) + } +}) describe('LinuxDoCallbackView', () => { beforeEach(() => { @@ -132,6 +136,64 @@ describe('LinuxDoCallbackView', () => { expect(replace).toHaveBeenCalledWith('/dashboard') }) + it('treats a completion without token as bind success and returns to profile', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({}) + + mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + expect(setToken).not.toHaveBeenCalled() + expect(showSuccess).toHaveBeenCalledWith('profile.authBindings.bindSuccess') + expect(replace).toHaveBeenCalledWith('/profile') + }) + + it('supports bind completion after adoption confirmation', async () => { + exchangePendingOAuthCompletion + .mockResolvedValueOnce({ + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'LinuxDo Nick', + suggested_avatar_url: 'https://cdn.example/linuxdo.png' + }) + .mockResolvedValueOnce({ + redirect: '/profile/security' + }) + + const wrapper = mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + await wrapper.findAll('button')[0].trigger('click') + await flushPromises() + + expect(exchangePendingOAuthCompletion).toHaveBeenNthCalledWith(2, { + adoptDisplayName: true, + adoptAvatar: true + }) + expect(setToken).not.toHaveBeenCalled() + expect(showSuccess).toHaveBeenCalledWith('profile.authBindings.bindSuccess') + expect(replace).toHaveBeenCalledWith('/profile/security') + }) + it('renders adoption choices for invitation flow and submits the selected values', async () => { exchangePendingOAuthCompletion.mockResolvedValue({ error: 'invitation_required', diff --git a/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts b/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts index 299c0746..f8de79f2 100644 --- a/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts +++ b/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts @@ -45,11 +45,15 @@ vi.mock('@/stores', () => ({ }) })) -vi.mock('@/api/auth', () => ({ - exchangePendingOAuthCompletion: (...args: any[]) => exchangePendingOAuthCompletion(...args), - completeOIDCOAuthRegistration: (...args: any[]) => completeOIDCOAuthRegistration(...args), - getPublicSettings: (...args: any[]) => getPublicSettings(...args) -})) +vi.mock('@/api/auth', async () => { + const actual = await vi.importActual('@/api/auth') + return { + ...actual, + exchangePendingOAuthCompletion: (...args: any[]) => exchangePendingOAuthCompletion(...args), + completeOIDCOAuthRegistration: (...args: any[]) => completeOIDCOAuthRegistration(...args), + getPublicSettings: (...args: any[]) => getPublicSettings(...args) + } +}) describe('OidcCallbackView', () => { beforeEach(() => { @@ -143,6 +147,43 @@ describe('OidcCallbackView', () => { expect(replace).toHaveBeenCalledWith('/dashboard') }) + it('supports bind completion after adoption confirmation', async () => { + exchangePendingOAuthCompletion + .mockResolvedValueOnce({ + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'OIDC Nick', + suggested_avatar_url: 'https://cdn.example/oidc.png' + }) + .mockResolvedValueOnce({ + redirect: '/profile' + }) + + const wrapper = mount(OidcCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + await wrapper.findAll('button')[0].trigger('click') + await flushPromises() + + expect(exchangePendingOAuthCompletion).toHaveBeenNthCalledWith(2, { + adoptDisplayName: true, + adoptAvatar: true + }) + expect(setToken).not.toHaveBeenCalled() + expect(showSuccess).toHaveBeenCalledWith('profile.authBindings.bindSuccess') + expect(replace).toHaveBeenCalledWith('/profile') + }) + it('renders adoption choices for invitation flow and submits the selected values', async () => { exchangePendingOAuthCompletion.mockResolvedValue({ error: 'invitation_required', diff --git a/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts b/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts index a9e2ada2..896bf15d 100644 --- a/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts +++ b/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts @@ -3,14 +3,16 @@ import { beforeEach, describe, expect, it, vi } from 'vitest' import WechatCallbackView from '@/views/auth/WechatCallbackView.vue' const { - postMock, + exchangePendingOAuthCompletionMock, + completeWeChatOAuthRegistrationMock, replaceMock, setTokenMock, showSuccessMock, showErrorMock, routeState, } = vi.hoisted(() => ({ - postMock: vi.fn(), + exchangePendingOAuthCompletionMock: vi.fn(), + completeWeChatOAuthRegistrationMock: vi.fn(), replaceMock: vi.fn(), setTokenMock: vi.fn(), showSuccessMock: vi.fn(), @@ -86,15 +88,19 @@ vi.mock('@/stores', () => ({ }), })) -vi.mock('@/api/client', () => ({ - apiClient: { - post: postMock, - }, -})) +vi.mock('@/api/auth', async () => { + const actual = await vi.importActual('@/api/auth') + return { + ...actual, + exchangePendingOAuthCompletion: (...args: any[]) => exchangePendingOAuthCompletionMock(...args), + completeWeChatOAuthRegistration: (...args: any[]) => completeWeChatOAuthRegistrationMock(...args), + } +}) describe('WechatCallbackView', () => { beforeEach(() => { - postMock.mockReset() + exchangePendingOAuthCompletionMock.mockReset() + completeWeChatOAuthRegistrationMock.mockReset() replaceMock.mockReset() setTokenMock.mockReset() showSuccessMock.mockReset() @@ -104,14 +110,12 @@ describe('WechatCallbackView', () => { }) it('does not send adoption decisions during the initial exchange', async () => { - postMock.mockResolvedValueOnce({ - data: { - access_token: 'access-token', - refresh_token: 'refresh-token', - expires_in: 3600, - redirect: '/dashboard', - adoption_required: true, - }, + exchangePendingOAuthCompletionMock.mockResolvedValue({ + access_token: 'access-token', + refresh_token: 'refresh-token', + expires_in: 3600, + redirect: '/dashboard', + adoption_required: true, }) setTokenMock.mockResolvedValue({}) @@ -128,28 +132,24 @@ describe('WechatCallbackView', () => { await flushPromises() - expect(postMock).toHaveBeenCalledWith('/auth/oauth/pending/exchange', {}) - expect(postMock).toHaveBeenCalledTimes(1) + expect(exchangePendingOAuthCompletionMock).toHaveBeenCalledWith() + expect(exchangePendingOAuthCompletionMock).toHaveBeenCalledTimes(1) }) it('waits for explicit adoption confirmation before finishing a non-invitation login', async () => { - postMock + exchangePendingOAuthCompletionMock .mockResolvedValueOnce({ - data: { - redirect: '/dashboard', - adoption_required: true, - suggested_display_name: 'WeChat Nick', - suggested_avatar_url: 'https://cdn.example/wechat.png', - }, + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'WeChat Nick', + suggested_avatar_url: 'https://cdn.example/wechat.png', }) .mockResolvedValueOnce({ - data: { - access_token: 'wechat-access-token', - refresh_token: 'wechat-refresh-token', - expires_in: 3600, - token_type: 'Bearer', - redirect: '/dashboard', - }, + access_token: 'wechat-access-token', + refresh_token: 'wechat-refresh-token', + expires_in: 3600, + token_type: 'Bearer', + redirect: '/dashboard', }) setTokenMock.mockResolvedValue({}) @@ -179,35 +179,67 @@ describe('WechatCallbackView', () => { await buttons[0].trigger('click') await flushPromises() - expect(postMock).toHaveBeenNthCalledWith(1, '/auth/oauth/pending/exchange', {}) - expect(postMock).toHaveBeenNthCalledWith(2, '/auth/oauth/pending/exchange', { - adopt_display_name: true, - adopt_avatar: false, + expect(exchangePendingOAuthCompletionMock).toHaveBeenNthCalledWith(1) + expect(exchangePendingOAuthCompletionMock).toHaveBeenNthCalledWith(2, { + adoptDisplayName: true, + adoptAvatar: false, }) expect(setTokenMock).toHaveBeenCalledWith('wechat-access-token') expect(replaceMock).toHaveBeenCalledWith('/dashboard') expect(localStorage.getItem('refresh_token')).toBe('wechat-refresh-token') }) + it('supports bind completion after adoption confirmation', async () => { + exchangePendingOAuthCompletionMock + .mockResolvedValueOnce({ + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'WeChat Nick', + suggested_avatar_url: 'https://cdn.example/wechat.png', + }) + .mockResolvedValueOnce({ + redirect: '/profile/connections', + }) + + const wrapper = mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + await wrapper.findAll('button')[0].trigger('click') + await flushPromises() + + expect(exchangePendingOAuthCompletionMock).toHaveBeenNthCalledWith(2, { + adoptDisplayName: true, + adoptAvatar: true, + }) + expect(setTokenMock).not.toHaveBeenCalled() + expect(showSuccessMock).toHaveBeenCalledWith('profile.authBindings.bindSuccess') + expect(replaceMock).toHaveBeenCalledWith('/profile/connections') + }) + it('renders adoption choices for invitation flow and submits the selected values', async () => { - postMock - .mockResolvedValueOnce({ - data: { - error: 'invitation_required', - redirect: '/subscriptions', - adoption_required: true, - suggested_display_name: 'WeChat Nick', - suggested_avatar_url: 'https://cdn.example/wechat.png', - }, - }) - .mockResolvedValueOnce({ - data: { - access_token: 'wechat-invite-token', - refresh_token: 'wechat-invite-refresh', - expires_in: 600, - token_type: 'Bearer', - }, - }) + exchangePendingOAuthCompletionMock.mockResolvedValue({ + error: 'invitation_required', + redirect: '/subscriptions', + adoption_required: true, + suggested_display_name: 'WeChat Nick', + suggested_avatar_url: 'https://cdn.example/wechat.png', + }) + completeWeChatOAuthRegistrationMock.mockResolvedValue({ + access_token: 'wechat-invite-token', + refresh_token: 'wechat-invite-refresh', + expires_in: 600, + token_type: 'Bearer', + }) const wrapper = mount(WechatCallbackView, { global: { @@ -230,10 +262,9 @@ describe('WechatCallbackView', () => { await wrapper.get('button').trigger('click') await flushPromises() - expect(postMock).toHaveBeenNthCalledWith(2, '/auth/oauth/wechat/complete-registration', { - invitation_code: 'INVITE-CODE', - adopt_display_name: false, - adopt_avatar: true, + expect(completeWeChatOAuthRegistrationMock).toHaveBeenCalledWith('INVITE-CODE', { + adoptDisplayName: false, + adoptAvatar: true, }) expect(setTokenMock).toHaveBeenCalledWith('wechat-invite-token') expect(replaceMock).toHaveBeenCalledWith('/subscriptions') diff --git a/frontend/src/views/user/ProfileView.vue b/frontend/src/views/user/ProfileView.vue index e7418ebb..f7418be9 100644 --- a/frontend/src/views/user/ProfileView.vue +++ b/frontend/src/views/user/ProfileView.vue @@ -2,18 +2,53 @@
- - - + + +
- -
+ + + +
-
-

{{ t('common.contactSupport') }}

{{ contactInfo }}

+
+ +
+
+

+ {{ t('common.contactSupport') }} +

+

{{ contactInfo }}

+
+ + +
@@ -29,26 +65,78 @@ \ No newline at end of file +onMounted(async () => { + const profileRefresh = authStore.refreshUser().catch((error) => { + console.error('Failed to refresh profile:', error) + }) + + const settingsLoad = authAPI.getPublicSettings() + .then((settings) => { + contactInfo.value = settings.contact_info || '' + balanceLowNotifyEnabled.value = settings.balance_low_notify_enabled ?? false + systemDefaultThreshold.value = settings.balance_low_notify_threshold ?? 0 + linuxdoOAuthEnabled.value = settings.linuxdo_oauth_enabled ?? false + wechatOAuthEnabled.value = settings.wechat_oauth_enabled ?? false + oidcOAuthEnabled.value = settings.oidc_oauth_enabled ?? false + oidcOAuthProviderName.value = settings.oidc_oauth_provider_name || 'OIDC' + }) + .catch((error) => { + console.error('Failed to load settings:', error) + }) + + await Promise.all([profileRefresh, settingsLoad]) +}) + +const formatCurrency = (value: number) => `$${value.toFixed(2)}` +