diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go index acd43e9f..ca3a5a77 100644 --- a/backend/internal/handler/auth_handler.go +++ b/backend/internal/handler/auth_handler.go @@ -78,9 +78,24 @@ type AuthResponse struct { User *dto.User `json:"user"` } +func ensureLoginUserActive(user *service.User) error { + if user == nil { + return infraerrors.Unauthorized("INVALID_USER", "user not found") + } + if !user.IsActive() { + return service.ErrUserNotActive + } + return nil +} + // respondWithTokenPair 生成 Token 对并返回认证响应 // 如果 Token 对生成失败,回退到只返回 Access Token(向后兼容) func (h *AuthHandler) respondWithTokenPair(c *gin.Context, user *service.User) { + if err := ensureLoginUserActive(user); err != nil { + response.ErrorFrom(c, err) + return + } + tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), user, "") if err != nil { slog.Error("failed to generate token pair", "error", err, "user_id", user.ID) @@ -293,6 +308,10 @@ func (h *AuthHandler) Login2FA(c *gin.Context) { response.ErrorFrom(c, err) return } + if err := ensureLoginUserActive(user); err != nil { + response.ErrorFrom(c, err) + return + } if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil { response.ErrorFrom(c, err) diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go index 157be066..a7e77c09 100644 --- a/backend/internal/handler/auth_linuxdo_oauth.go +++ b/backend/internal/handler/auth_linuxdo_oauth.go @@ -495,7 +495,7 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) { response.ErrorFrom(c, err) return } - decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{ + decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{ AdoptDisplayName: req.AdoptDisplayName, AdoptAvatar: req.AdoptAvatar, }) diff --git a/backend/internal/handler/auth_linuxdo_oauth_test.go b/backend/internal/handler/auth_linuxdo_oauth_test.go index a9a5e3e6..d535c178 100644 --- a/backend/internal/handler/auth_linuxdo_oauth_test.go +++ b/backend/internal/handler/auth_linuxdo_oauth_test.go @@ -408,6 +408,74 @@ func TestLinuxDoOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUser(t require.Nil(t, completion["error"]) } +func TestLinuxDoOAuthCallbackRejectsDisabledExistingIdentityUser(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/token": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`)) + case "/userinfo": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"654","username":"linuxdo_disabled","name":"LinuxDo Disabled"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + + handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{ + Enabled: true, + ClientID: "linuxdo-client", + ClientSecret: "linuxdo-secret", + AuthorizeURL: upstream.URL + "/authorize", + TokenURL: upstream.URL + "/token", + UserInfoURL: upstream.URL + "/userinfo", + Scopes: "read", + RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback", + FrontendRedirectURL: "/auth/linuxdo/callback", + TokenAuthMethod: "client_secret_post", + UsePKCE: true, + }) + t.Cleanup(func() { _ = client.Close() }) + + ctx := context.Background() + existingUser, err := client.User.Create(). + SetEmail(linuxDoSyntheticEmail("654")). + SetUsername("disabled-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusDisabled). + Save(ctx) + require.NoError(t, err) + _, err = client.AuthIdentity.Create(). + SetUserID(existingUser.ID). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("654"). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=code-disabled&state=state-disabled", nil) + req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-disabled")) + req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/dashboard")) + req.AddCookie(encodedCookie(linuxDoOAuthVerifierCookie, "verifier-disabled")) + req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentLogin)) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-disabled")) + c.Request = req + + handler.LinuxDoOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)) + assertOAuthRedirectError(t, recorder.Header().Get("Location"), "session_error", "USER_NOT_ACTIVE") + + count, err := client.PendingAuthSession.Query().Count(ctx) + require.NoError(t, err) + require.Zero(t, count) +} + func TestLinuxDoOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *testing.T) { upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { @@ -812,6 +880,69 @@ func TestCompleteLinuxDoOAuthRegistrationReturnsPendingSessionWhenChoiceStillReq require.Nil(t, storedSession.ConsumedAt) } +func TestCompleteLinuxDoOAuthRegistrationBindsIdentityWithoutAdoptionFlags(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("linuxdo-complete-no-adoption-session"). + SetIntent("login"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("linuxdo-subject-no-adoption"). + SetResolvedEmail("linuxdo-subject-no-adoption@linuxdo-connect.invalid"). + SetBrowserSessionKey("linuxdo-browser-no-adoption"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "linuxdo_user", + "suggested_display_name": "LinuxDo Legacy", + "suggested_avatar_url": "https://cdn.example/linuxdo-legacy.png", + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-browser-no-adoption")}) + c.Request = req + + handler.CompleteLinuxDoOAuthRegistration(c) + + require.Equal(t, http.StatusOK, recorder.Code) + responseData := decodeJSONBody(t, recorder) + require.NotEmpty(t, responseData["access_token"]) + require.NotEmpty(t, responseData["refresh_token"]) + + userEntity, err := client.User.Query(). + Where(dbuser.EmailEQ(session.ResolvedEmail)). + Only(ctx) + require.NoError(t, err) + require.Equal(t, "linuxdo_user", userEntity.Username) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("linuxdo"), + authidentity.ProviderKeyEQ("linuxdo"), + authidentity.ProviderSubjectEQ("linuxdo-subject-no-adoption"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, userEntity.ID, identity.UserID) + + decision, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, decision.IdentityID) + require.Equal(t, identity.ID, *decision.IdentityID) + require.False(t, decision.AdoptDisplayName) + require.False(t, decision.AdoptAvatar) +} + func newLinuxDoOAuthTestHandler(t *testing.T, invitationEnabled bool, oauthCfg config.LinuxDoConnectConfig) *AuthHandler { t.Helper() handler, _ := newLinuxDoOAuthHandlerAndClient(t, invitationEnabled, oauthCfg) diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go index c5df4db1..7be01e74 100644 --- a/backend/internal/handler/auth_oauth_pending_flow.go +++ b/backend/internal/handler/auth_oauth_pending_flow.go @@ -464,15 +464,7 @@ func (h *AuthHandler) findOAuthIdentityUser(ctx context.Context, identity servic } return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err) } - - userEntity, err := client.User.Get(ctx, record.UserID) - if err != nil { - if dbent.IsNotFound(err) { - return nil, nil - } - return nil, infraerrors.InternalServer("AUTH_IDENTITY_USER_LOOKUP_FAILED", "failed to load auth identity user").WithCause(err) - } - return userEntity, nil + return findActiveUserByID(ctx, client, record.UserID) } func (h *AuthHandler) BindLinuxDoOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "linuxdo") } @@ -998,6 +990,9 @@ func findActiveUserByID(ctx context.Context, client *dbent.Client, userID int64) } return nil, infraerrors.InternalServer("AUTH_IDENTITY_USER_LOOKUP_FAILED", "failed to load auth identity user").WithCause(err) } + if !strings.EqualFold(strings.TrimSpace(userEntity.Status), service.StatusActive) { + return nil, service.ErrUserNotActive + } return userEntity, nil } @@ -1801,6 +1796,11 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) { response.ErrorFrom(c, err) return } + if err := ensureLoginUserActive(loginUser); err != nil { + clearCookies() + response.ErrorFrom(c, err) + return + } if err := h.ensureBackendModeAllowsUser(c.Request.Context(), loginUser); err != nil { clearCookies() response.ErrorFrom(c, err) diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go index b3b8dfe1..bc8fe7eb 100644 --- a/backend/internal/handler/auth_oauth_pending_flow_test.go +++ b/backend/internal/handler/auth_oauth_pending_flow_test.go @@ -851,6 +851,56 @@ func TestExchangePendingOAuthCompletionBlocksBackendModeBeforeReturningTokenPayl require.Nil(t, storedSession.ConsumedAt) } +func TestExchangePendingOAuthCompletionRejectsDisabledTargetUser(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + userEntity, err := client.User.Create(). + SetEmail("disabled-linked@example.com"). + SetUsername("disabled-linked-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusDisabled). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("disabled-linked-session-token"). + SetIntent("login"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("disabled-linked-subject"). + SetTargetUserID(userEntity.ID). + SetResolvedEmail(userEntity.Email). + SetBrowserSessionKey("disabled-linked-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "suggested_display_name": "Disabled Linked User", + }). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "redirect": "/dashboard", + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil) + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("disabled-linked-browser-session-key")}) + ginCtx.Request = req + + handler.ExchangePendingOAuthCompletion(ginCtx) + + require.Equal(t, http.StatusForbidden, recorder.Code) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + func TestNormalizePendingOAuthCompletionResponseScrubsLegacyTokenPayload(t *testing.T) { payload := normalizePendingOAuthCompletionResponse(map[string]any{ "access_token": "legacy-access-token", diff --git a/backend/internal/handler/auth_oauth_test_helpers_test.go b/backend/internal/handler/auth_oauth_test_helpers_test.go index 8eb87dbb..47bad942 100644 --- a/backend/internal/handler/auth_oauth_test_helpers_test.go +++ b/backend/internal/handler/auth_oauth_test_helpers_test.go @@ -2,6 +2,7 @@ package handler import ( "net/http" + "net/url" "testing" "github.com/stretchr/testify/require" @@ -37,3 +38,20 @@ func decodeCookieValueForTest(t *testing.T, value string) string { require.NoError(t, err) return decoded } + +func assertOAuthRedirectError(t *testing.T, location string, errorCode string, errorMessage string) { + t.Helper() + require.NotEmpty(t, location) + + parsed, err := url.Parse(location) + require.NoError(t, err) + + rawValues := parsed.RawQuery + if rawValues == "" { + rawValues = parsed.Fragment + } + values, err := url.ParseQuery(rawValues) + require.NoError(t, err) + require.Equal(t, errorCode, values.Get("error")) + require.Equal(t, errorMessage, values.Get("error_message")) +} diff --git a/backend/internal/handler/auth_oidc_oauth.go b/backend/internal/handler/auth_oidc_oauth.go index 6345938b..3c67e421 100644 --- a/backend/internal/handler/auth_oidc_oauth.go +++ b/backend/internal/handler/auth_oidc_oauth.go @@ -648,7 +648,7 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) { response.ErrorFrom(c, err) return } - decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{ + decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{ AdoptDisplayName: req.AdoptDisplayName, AdoptAvatar: req.AdoptAvatar, }) diff --git a/backend/internal/handler/auth_oidc_oauth_test.go b/backend/internal/handler/auth_oidc_oauth_test.go index 63008344..c2855dc9 100644 --- a/backend/internal/handler/auth_oidc_oauth_test.go +++ b/backend/internal/handler/auth_oidc_oauth_test.go @@ -340,6 +340,56 @@ func TestOIDCOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUser(t *t require.Nil(t, completion["error"]) } +func TestOIDCOAuthCallbackRejectsDisabledExistingIdentityUser(t *testing.T) { + cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{ + Subject: "oidc-disabled-subject", + PreferredUsername: "oidc_disabled", + DisplayName: "OIDC Disabled", + }) + defer cleanup() + + handler, client := newOIDCOAuthHandlerAndClient(t, false, cfg) + t.Cleanup(func() { _ = client.Close() }) + + ctx := context.Background() + existingUser, err := client.User.Create(). + SetEmail(oidcSyntheticEmailFromIdentityKey(oidcIdentityKey(cfg.IssuerURL, "oidc-disabled-subject"))). + SetUsername("disabled-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusDisabled). + Save(ctx) + require.NoError(t, err) + _, err = client.AuthIdentity.Create(). + SetUserID(existingUser.ID). + SetProviderType("oidc"). + SetProviderKey(cfg.IssuerURL). + SetProviderSubject("oidc-disabled-subject"). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-disabled", nil) + req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-disabled")) + req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/dashboard")) + req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-disabled")) + req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-disabled-subject")) + req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin)) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-disabled")) + c.Request = req + + handler.OIDCOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)) + assertOAuthRedirectError(t, recorder.Header().Get("Location"), "session_error", "USER_NOT_ACTIVE") + + count, err := client.PendingAuthSession.Query().Count(ctx) + require.NoError(t, err) + require.Zero(t, count) +} + func TestOIDCOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *testing.T) { cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{ Subject: "oidc-subject-compat", @@ -748,6 +798,70 @@ func TestCompleteOIDCOAuthRegistrationReturnsPendingSessionWhenChoiceStillRequir require.Nil(t, storedSession.ConsumedAt) } +func TestCompleteOIDCOAuthRegistrationBindsIdentityWithoutAdoptionFlags(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("oidc-complete-no-adoption-session"). + SetIntent("login"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example.com"). + SetProviderSubject("oidc-subject-no-adoption"). + SetResolvedEmail("8c9f12b2a2e14b1db9efc08b27e0ef5c@oidc-connect.invalid"). + SetBrowserSessionKey("oidc-browser-no-adoption"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + "issuer": "https://issuer.example.com", + "suggested_display_name": "OIDC Legacy", + "suggested_avatar_url": "https://cdn.example/oidc-legacy.png", + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-browser-no-adoption")}) + c.Request = req + + handler.CompleteOIDCOAuthRegistration(c) + + require.Equal(t, http.StatusOK, recorder.Code) + responseData := decodeJSONBody(t, recorder) + require.NotEmpty(t, responseData["access_token"]) + require.NotEmpty(t, responseData["refresh_token"]) + + userEntity, err := client.User.Query(). + Where(dbuser.EmailEQ(session.ResolvedEmail)). + Only(ctx) + require.NoError(t, err) + require.Equal(t, "oidc_user", userEntity.Username) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("oidc"), + authidentity.ProviderKeyEQ("https://issuer.example.com"), + authidentity.ProviderSubjectEQ("oidc-subject-no-adoption"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, userEntity.ID, identity.UserID) + + decision, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, decision.IdentityID) + require.Equal(t, identity.ID, *decision.IdentityID) + require.False(t, decision.AdoptDisplayName) + require.False(t, decision.AdoptAvatar) +} + type oidcProviderFixture struct { Subject string PreferredUsername string diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go index 3ed20a7d..dc93fcae 100644 --- a/backend/internal/handler/auth_wechat_oauth.go +++ b/backend/internal/handler/auth_wechat_oauth.go @@ -551,7 +551,7 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) { response.ErrorFrom(c, err) return } - decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{ + decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{ AdoptDisplayName: req.AdoptDisplayName, AdoptAvatar: req.AdoptAvatar, }) @@ -827,7 +827,10 @@ func (h *AuthHandler) findWeChatUserByLegacyOpenID( return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err) } if user, err := singleWeChatIdentityUser(records); err != nil || user != nil { - return user, err + if err != nil || user == nil { + return user, err + } + return findActiveUserByID(ctx, client, user.ID) } } @@ -851,7 +854,10 @@ func (h *AuthHandler) findWeChatUserByLegacyOpenID( return nil, infraerrors.InternalServer("AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED", "failed to inspect auth identity channel ownership").WithCause(err) } if user, err := singleWeChatChannelUser(records); err != nil || user != nil { - return user, err + if err != nil || user == nil { + return user, err + } + return findActiveUserByID(ctx, client, user.ID) } } @@ -870,7 +876,11 @@ func (h *AuthHandler) findWeChatUserByLegacyOpenID( if err != nil { return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err) } - return singleWeChatIdentityUser(records) + user, err := singleWeChatIdentityUser(records) + if err != nil || user == nil { + return user, err + } + return findActiveUserByID(ctx, client, user.ID) } func wechatCompatibleProviderKeys(providerKey string) []string { diff --git a/backend/internal/handler/auth_wechat_oauth_test.go b/backend/internal/handler/auth_wechat_oauth_test.go index 349e7dd2..b8bd21ce 100644 --- a/backend/internal/handler/auth_wechat_oauth_test.go +++ b/backend/internal/handler/auth_wechat_oauth_test.go @@ -19,6 +19,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/enttest" "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + dbuser "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/payment" "github.com/Wei-Shaw/sub2api/internal/repository" @@ -292,6 +293,71 @@ func TestWeChatOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUserWit require.False(t, hasRefreshToken) } +func TestWeChatOAuthCallbackRejectsDisabledExistingIdentityUser(t *testing.T) { + originalAccessTokenURL := wechatOAuthAccessTokenURL + originalUserInfoURL := wechatOAuthUserInfoURL + t.Cleanup(func() { + wechatOAuthAccessTokenURL = originalAccessTokenURL + wechatOAuthUserInfoURL = originalUserInfoURL + }) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-disabled","unionid":"union-disabled","scope":"snsapi_login"}`)) + case strings.Contains(r.URL.Path, "/sns/userinfo"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"openid":"openid-disabled","unionid":"union-disabled","nickname":"Disabled WeChat","headimgurl":"https://cdn.example/disabled.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" + wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" + + handler, client := newWeChatOAuthTestHandler(t, false) + defer client.Close() + + ctx := context.Background() + existingUser, err := client.User.Create(). + SetEmail(wechatSyntheticEmail("union-disabled")). + SetUsername("disabled-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusDisabled). + Save(ctx) + require.NoError(t, err) + _, err = client.AuthIdentity.Create(). + SetUserID(existingUser.ID). + SetProviderType("wechat"). + SetProviderKey(wechatOAuthProviderKey). + SetProviderSubject("union-disabled"). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-disabled", nil) + req.Host = "api.example.com" + req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-disabled")) + req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) + req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-disabled")) + c.Request = req + + handler.WeChatOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)) + assertOAuthRedirectError(t, recorder.Header().Get("Location"), "session_error", "USER_NOT_ACTIVE") + + count, err := client.PendingAuthSession.Query().Count(ctx) + require.NoError(t, err) + require.Zero(t, count) +} + func TestWeChatPaymentOAuthCallbackRedirectsWithOpaqueResumeToken(t *testing.T) { originalAccessTokenURL := wechatOAuthAccessTokenURL t.Cleanup(func() { @@ -816,6 +882,73 @@ func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSessionReturnsPend require.Zero(t, decisionCount) } +func TestCompleteWeChatOAuthRegistrationBindsIdentityWithoutAdoptionFlags(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("wechat-complete-no-adoption-session"). + SetIntent("login"). + SetProviderType("wechat"). + SetProviderKey(wechatOAuthProviderKey). + SetProviderSubject("wechat-subject-no-adoption"). + SetResolvedEmail("wechat-subject-no-adoption@wechat-connect.invalid"). + SetBrowserSessionKey("wechat-browser-no-adoption"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "wechat_user", + "suggested_display_name": "WeChat Legacy", + "suggested_avatar_url": "https://cdn.example/wechat-legacy.png", + "mode": "open", + "channel": "open", + "channel_app_id": "wx-open-app", + "channel_subject": "openid-legacy", + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`) + recorder := httptest.NewRecorder() + completeCtx, _ := gin.CreateTestContext(recorder) + completeReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/wechat/complete-registration", body) + completeReq.Header.Set("Content-Type", "application/json") + completeReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + completeReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("wechat-browser-no-adoption")}) + completeCtx.Request = completeReq + + handler.CompleteWeChatOAuthRegistration(completeCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + responseData := decodeJSONBody(t, recorder) + require.NotEmpty(t, responseData["access_token"]) + require.NotEmpty(t, responseData["refresh_token"]) + + userEntity, err := client.User.Query(). + Where(dbuser.EmailEQ(session.ResolvedEmail)). + Only(ctx) + require.NoError(t, err) + require.Equal(t, "wechat_user", userEntity.Username) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("wechat"), + authidentity.ProviderKeyEQ(wechatOAuthProviderKey), + authidentity.ProviderSubjectEQ("wechat-subject-no-adoption"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, userEntity.ID, identity.UserID) + + decision, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, decision.IdentityID) + require.Equal(t, identity.ID, *decision.IdentityID) + require.False(t, decision.AdoptDisplayName) + require.False(t, decision.AdoptAvatar) +} + func TestWeChatOAuthCallbackRepairsLegacyOpenIDOnlyIdentity(t *testing.T) { originalAccessTokenURL := wechatOAuthAccessTokenURL originalUserInfoURL := wechatOAuthUserInfoURL