fix(auth): harden oauth callback adoption flows
This commit is contained in:
@@ -78,9 +78,24 @@ type AuthResponse struct {
|
|||||||
User *dto.User `json:"user"`
|
User *dto.User `json:"user"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ensureLoginUserActive(user *service.User) error {
|
||||||
|
if user == nil {
|
||||||
|
return infraerrors.Unauthorized("INVALID_USER", "user not found")
|
||||||
|
}
|
||||||
|
if !user.IsActive() {
|
||||||
|
return service.ErrUserNotActive
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// respondWithTokenPair 生成 Token 对并返回认证响应
|
// respondWithTokenPair 生成 Token 对并返回认证响应
|
||||||
// 如果 Token 对生成失败,回退到只返回 Access Token(向后兼容)
|
// 如果 Token 对生成失败,回退到只返回 Access Token(向后兼容)
|
||||||
func (h *AuthHandler) respondWithTokenPair(c *gin.Context, user *service.User) {
|
func (h *AuthHandler) respondWithTokenPair(c *gin.Context, user *service.User) {
|
||||||
|
if err := ensureLoginUserActive(user); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), user, "")
|
tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), user, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("failed to generate token pair", "error", err, "user_id", user.ID)
|
slog.Error("failed to generate token pair", "error", err, "user_id", user.ID)
|
||||||
@@ -293,6 +308,10 @@ func (h *AuthHandler) Login2FA(c *gin.Context) {
|
|||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if err := ensureLoginUserActive(user); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil {
|
if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
|
|||||||
@@ -495,7 +495,7 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
|
|||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
|
decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
|
||||||
AdoptDisplayName: req.AdoptDisplayName,
|
AdoptDisplayName: req.AdoptDisplayName,
|
||||||
AdoptAvatar: req.AdoptAvatar,
|
AdoptAvatar: req.AdoptAvatar,
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -408,6 +408,74 @@ func TestLinuxDoOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUser(t
|
|||||||
require.Nil(t, completion["error"])
|
require.Nil(t, completion["error"])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestLinuxDoOAuthCallbackRejectsDisabledExistingIdentityUser(t *testing.T) {
|
||||||
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/token":
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`))
|
||||||
|
case "/userinfo":
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write([]byte(`{"id":"654","username":"linuxdo_disabled","name":"LinuxDo Disabled"}`))
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer upstream.Close()
|
||||||
|
|
||||||
|
handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{
|
||||||
|
Enabled: true,
|
||||||
|
ClientID: "linuxdo-client",
|
||||||
|
ClientSecret: "linuxdo-secret",
|
||||||
|
AuthorizeURL: upstream.URL + "/authorize",
|
||||||
|
TokenURL: upstream.URL + "/token",
|
||||||
|
UserInfoURL: upstream.URL + "/userinfo",
|
||||||
|
Scopes: "read",
|
||||||
|
RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback",
|
||||||
|
FrontendRedirectURL: "/auth/linuxdo/callback",
|
||||||
|
TokenAuthMethod: "client_secret_post",
|
||||||
|
UsePKCE: true,
|
||||||
|
})
|
||||||
|
t.Cleanup(func() { _ = client.Close() })
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
existingUser, err := client.User.Create().
|
||||||
|
SetEmail(linuxDoSyntheticEmail("654")).
|
||||||
|
SetUsername("disabled-user").
|
||||||
|
SetPasswordHash("hash").
|
||||||
|
SetRole(service.RoleUser).
|
||||||
|
SetStatus(service.StatusDisabled).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = client.AuthIdentity.Create().
|
||||||
|
SetUserID(existingUser.ID).
|
||||||
|
SetProviderType("linuxdo").
|
||||||
|
SetProviderKey("linuxdo").
|
||||||
|
SetProviderSubject("654").
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=code-disabled&state=state-disabled", nil)
|
||||||
|
req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-disabled"))
|
||||||
|
req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/dashboard"))
|
||||||
|
req.AddCookie(encodedCookie(linuxDoOAuthVerifierCookie, "verifier-disabled"))
|
||||||
|
req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentLogin))
|
||||||
|
req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-disabled"))
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
handler.LinuxDoOAuthCallback(c)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusFound, recorder.Code)
|
||||||
|
require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
|
||||||
|
assertOAuthRedirectError(t, recorder.Header().Get("Location"), "session_error", "USER_NOT_ACTIVE")
|
||||||
|
|
||||||
|
count, err := client.PendingAuthSession.Query().Count(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Zero(t, count)
|
||||||
|
}
|
||||||
|
|
||||||
func TestLinuxDoOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *testing.T) {
|
func TestLinuxDoOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *testing.T) {
|
||||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
switch r.URL.Path {
|
switch r.URL.Path {
|
||||||
@@ -812,6 +880,69 @@ func TestCompleteLinuxDoOAuthRegistrationReturnsPendingSessionWhenChoiceStillReq
|
|||||||
require.Nil(t, storedSession.ConsumedAt)
|
require.Nil(t, storedSession.ConsumedAt)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCompleteLinuxDoOAuthRegistrationBindsIdentityWithoutAdoptionFlags(t *testing.T) {
|
||||||
|
handler, client := newOAuthPendingFlowTestHandler(t, false)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
session, err := client.PendingAuthSession.Create().
|
||||||
|
SetSessionToken("linuxdo-complete-no-adoption-session").
|
||||||
|
SetIntent("login").
|
||||||
|
SetProviderType("linuxdo").
|
||||||
|
SetProviderKey("linuxdo").
|
||||||
|
SetProviderSubject("linuxdo-subject-no-adoption").
|
||||||
|
SetResolvedEmail("linuxdo-subject-no-adoption@linuxdo-connect.invalid").
|
||||||
|
SetBrowserSessionKey("linuxdo-browser-no-adoption").
|
||||||
|
SetUpstreamIdentityClaims(map[string]any{
|
||||||
|
"username": "linuxdo_user",
|
||||||
|
"suggested_display_name": "LinuxDo Legacy",
|
||||||
|
"suggested_avatar_url": "https://cdn.example/linuxdo-legacy.png",
|
||||||
|
}).
|
||||||
|
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
|
||||||
|
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-browser-no-adoption")})
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
handler.CompleteLinuxDoOAuthRegistration(c)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
responseData := decodeJSONBody(t, recorder)
|
||||||
|
require.NotEmpty(t, responseData["access_token"])
|
||||||
|
require.NotEmpty(t, responseData["refresh_token"])
|
||||||
|
|
||||||
|
userEntity, err := client.User.Query().
|
||||||
|
Where(dbuser.EmailEQ(session.ResolvedEmail)).
|
||||||
|
Only(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "linuxdo_user", userEntity.Username)
|
||||||
|
|
||||||
|
identity, err := client.AuthIdentity.Query().
|
||||||
|
Where(
|
||||||
|
authidentity.ProviderTypeEQ("linuxdo"),
|
||||||
|
authidentity.ProviderKeyEQ("linuxdo"),
|
||||||
|
authidentity.ProviderSubjectEQ("linuxdo-subject-no-adoption"),
|
||||||
|
).
|
||||||
|
Only(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, userEntity.ID, identity.UserID)
|
||||||
|
|
||||||
|
decision, err := client.IdentityAdoptionDecision.Query().
|
||||||
|
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
|
||||||
|
Only(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, decision.IdentityID)
|
||||||
|
require.Equal(t, identity.ID, *decision.IdentityID)
|
||||||
|
require.False(t, decision.AdoptDisplayName)
|
||||||
|
require.False(t, decision.AdoptAvatar)
|
||||||
|
}
|
||||||
|
|
||||||
func newLinuxDoOAuthTestHandler(t *testing.T, invitationEnabled bool, oauthCfg config.LinuxDoConnectConfig) *AuthHandler {
|
func newLinuxDoOAuthTestHandler(t *testing.T, invitationEnabled bool, oauthCfg config.LinuxDoConnectConfig) *AuthHandler {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
handler, _ := newLinuxDoOAuthHandlerAndClient(t, invitationEnabled, oauthCfg)
|
handler, _ := newLinuxDoOAuthHandlerAndClient(t, invitationEnabled, oauthCfg)
|
||||||
|
|||||||
@@ -464,15 +464,7 @@ func (h *AuthHandler) findOAuthIdentityUser(ctx context.Context, identity servic
|
|||||||
}
|
}
|
||||||
return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
|
return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
|
||||||
}
|
}
|
||||||
|
return findActiveUserByID(ctx, client, record.UserID)
|
||||||
userEntity, err := client.User.Get(ctx, record.UserID)
|
|
||||||
if err != nil {
|
|
||||||
if dbent.IsNotFound(err) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
return nil, infraerrors.InternalServer("AUTH_IDENTITY_USER_LOOKUP_FAILED", "failed to load auth identity user").WithCause(err)
|
|
||||||
}
|
|
||||||
return userEntity, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *AuthHandler) BindLinuxDoOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "linuxdo") }
|
func (h *AuthHandler) BindLinuxDoOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "linuxdo") }
|
||||||
@@ -998,6 +990,9 @@ func findActiveUserByID(ctx context.Context, client *dbent.Client, userID int64)
|
|||||||
}
|
}
|
||||||
return nil, infraerrors.InternalServer("AUTH_IDENTITY_USER_LOOKUP_FAILED", "failed to load auth identity user").WithCause(err)
|
return nil, infraerrors.InternalServer("AUTH_IDENTITY_USER_LOOKUP_FAILED", "failed to load auth identity user").WithCause(err)
|
||||||
}
|
}
|
||||||
|
if !strings.EqualFold(strings.TrimSpace(userEntity.Status), service.StatusActive) {
|
||||||
|
return nil, service.ErrUserNotActive
|
||||||
|
}
|
||||||
return userEntity, nil
|
return userEntity, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1801,6 +1796,11 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
|
|||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if err := ensureLoginUserActive(loginUser); err != nil {
|
||||||
|
clearCookies()
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
if err := h.ensureBackendModeAllowsUser(c.Request.Context(), loginUser); err != nil {
|
if err := h.ensureBackendModeAllowsUser(c.Request.Context(), loginUser); err != nil {
|
||||||
clearCookies()
|
clearCookies()
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
|
|||||||
@@ -851,6 +851,56 @@ func TestExchangePendingOAuthCompletionBlocksBackendModeBeforeReturningTokenPayl
|
|||||||
require.Nil(t, storedSession.ConsumedAt)
|
require.Nil(t, storedSession.ConsumedAt)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestExchangePendingOAuthCompletionRejectsDisabledTargetUser(t *testing.T) {
|
||||||
|
handler, client := newOAuthPendingFlowTestHandler(t, false)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
userEntity, err := client.User.Create().
|
||||||
|
SetEmail("disabled-linked@example.com").
|
||||||
|
SetUsername("disabled-linked-user").
|
||||||
|
SetPasswordHash("hash").
|
||||||
|
SetRole(service.RoleUser).
|
||||||
|
SetStatus(service.StatusDisabled).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
session, err := client.PendingAuthSession.Create().
|
||||||
|
SetSessionToken("disabled-linked-session-token").
|
||||||
|
SetIntent("login").
|
||||||
|
SetProviderType("linuxdo").
|
||||||
|
SetProviderKey("linuxdo").
|
||||||
|
SetProviderSubject("disabled-linked-subject").
|
||||||
|
SetTargetUserID(userEntity.ID).
|
||||||
|
SetResolvedEmail(userEntity.Email).
|
||||||
|
SetBrowserSessionKey("disabled-linked-browser-session-key").
|
||||||
|
SetUpstreamIdentityClaims(map[string]any{
|
||||||
|
"suggested_display_name": "Disabled Linked User",
|
||||||
|
}).
|
||||||
|
SetLocalFlowState(map[string]any{
|
||||||
|
oauthCompletionResponseKey: map[string]any{
|
||||||
|
"redirect": "/dashboard",
|
||||||
|
},
|
||||||
|
}).
|
||||||
|
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(recorder)
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil)
|
||||||
|
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
|
||||||
|
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("disabled-linked-browser-session-key")})
|
||||||
|
ginCtx.Request = req
|
||||||
|
|
||||||
|
handler.ExchangePendingOAuthCompletion(ginCtx)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusForbidden, recorder.Code)
|
||||||
|
|
||||||
|
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Nil(t, storedSession.ConsumedAt)
|
||||||
|
}
|
||||||
|
|
||||||
func TestNormalizePendingOAuthCompletionResponseScrubsLegacyTokenPayload(t *testing.T) {
|
func TestNormalizePendingOAuthCompletionResponseScrubsLegacyTokenPayload(t *testing.T) {
|
||||||
payload := normalizePendingOAuthCompletionResponse(map[string]any{
|
payload := normalizePendingOAuthCompletionResponse(map[string]any{
|
||||||
"access_token": "legacy-access-token",
|
"access_token": "legacy-access-token",
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@@ -37,3 +38,20 @@ func decodeCookieValueForTest(t *testing.T, value string) string {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
return decoded
|
return decoded
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func assertOAuthRedirectError(t *testing.T, location string, errorCode string, errorMessage string) {
|
||||||
|
t.Helper()
|
||||||
|
require.NotEmpty(t, location)
|
||||||
|
|
||||||
|
parsed, err := url.Parse(location)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rawValues := parsed.RawQuery
|
||||||
|
if rawValues == "" {
|
||||||
|
rawValues = parsed.Fragment
|
||||||
|
}
|
||||||
|
values, err := url.ParseQuery(rawValues)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, errorCode, values.Get("error"))
|
||||||
|
require.Equal(t, errorMessage, values.Get("error_message"))
|
||||||
|
}
|
||||||
|
|||||||
@@ -648,7 +648,7 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
|
|||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
|
decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
|
||||||
AdoptDisplayName: req.AdoptDisplayName,
|
AdoptDisplayName: req.AdoptDisplayName,
|
||||||
AdoptAvatar: req.AdoptAvatar,
|
AdoptAvatar: req.AdoptAvatar,
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -340,6 +340,56 @@ func TestOIDCOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUser(t *t
|
|||||||
require.Nil(t, completion["error"])
|
require.Nil(t, completion["error"])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOIDCOAuthCallbackRejectsDisabledExistingIdentityUser(t *testing.T) {
|
||||||
|
cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{
|
||||||
|
Subject: "oidc-disabled-subject",
|
||||||
|
PreferredUsername: "oidc_disabled",
|
||||||
|
DisplayName: "OIDC Disabled",
|
||||||
|
})
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
handler, client := newOIDCOAuthHandlerAndClient(t, false, cfg)
|
||||||
|
t.Cleanup(func() { _ = client.Close() })
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
existingUser, err := client.User.Create().
|
||||||
|
SetEmail(oidcSyntheticEmailFromIdentityKey(oidcIdentityKey(cfg.IssuerURL, "oidc-disabled-subject"))).
|
||||||
|
SetUsername("disabled-user").
|
||||||
|
SetPasswordHash("hash").
|
||||||
|
SetRole(service.RoleUser).
|
||||||
|
SetStatus(service.StatusDisabled).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = client.AuthIdentity.Create().
|
||||||
|
SetUserID(existingUser.ID).
|
||||||
|
SetProviderType("oidc").
|
||||||
|
SetProviderKey(cfg.IssuerURL).
|
||||||
|
SetProviderSubject("oidc-disabled-subject").
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-disabled", nil)
|
||||||
|
req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-disabled"))
|
||||||
|
req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/dashboard"))
|
||||||
|
req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-disabled"))
|
||||||
|
req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-disabled-subject"))
|
||||||
|
req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin))
|
||||||
|
req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-disabled"))
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
handler.OIDCOAuthCallback(c)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusFound, recorder.Code)
|
||||||
|
require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
|
||||||
|
assertOAuthRedirectError(t, recorder.Header().Get("Location"), "session_error", "USER_NOT_ACTIVE")
|
||||||
|
|
||||||
|
count, err := client.PendingAuthSession.Query().Count(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Zero(t, count)
|
||||||
|
}
|
||||||
|
|
||||||
func TestOIDCOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *testing.T) {
|
func TestOIDCOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *testing.T) {
|
||||||
cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{
|
cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{
|
||||||
Subject: "oidc-subject-compat",
|
Subject: "oidc-subject-compat",
|
||||||
@@ -748,6 +798,70 @@ func TestCompleteOIDCOAuthRegistrationReturnsPendingSessionWhenChoiceStillRequir
|
|||||||
require.Nil(t, storedSession.ConsumedAt)
|
require.Nil(t, storedSession.ConsumedAt)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCompleteOIDCOAuthRegistrationBindsIdentityWithoutAdoptionFlags(t *testing.T) {
|
||||||
|
handler, client := newOAuthPendingFlowTestHandler(t, false)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
session, err := client.PendingAuthSession.Create().
|
||||||
|
SetSessionToken("oidc-complete-no-adoption-session").
|
||||||
|
SetIntent("login").
|
||||||
|
SetProviderType("oidc").
|
||||||
|
SetProviderKey("https://issuer.example.com").
|
||||||
|
SetProviderSubject("oidc-subject-no-adoption").
|
||||||
|
SetResolvedEmail("8c9f12b2a2e14b1db9efc08b27e0ef5c@oidc-connect.invalid").
|
||||||
|
SetBrowserSessionKey("oidc-browser-no-adoption").
|
||||||
|
SetUpstreamIdentityClaims(map[string]any{
|
||||||
|
"username": "oidc_user",
|
||||||
|
"issuer": "https://issuer.example.com",
|
||||||
|
"suggested_display_name": "OIDC Legacy",
|
||||||
|
"suggested_avatar_url": "https://cdn.example/oidc-legacy.png",
|
||||||
|
}).
|
||||||
|
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
|
||||||
|
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-browser-no-adoption")})
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
handler.CompleteOIDCOAuthRegistration(c)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
responseData := decodeJSONBody(t, recorder)
|
||||||
|
require.NotEmpty(t, responseData["access_token"])
|
||||||
|
require.NotEmpty(t, responseData["refresh_token"])
|
||||||
|
|
||||||
|
userEntity, err := client.User.Query().
|
||||||
|
Where(dbuser.EmailEQ(session.ResolvedEmail)).
|
||||||
|
Only(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "oidc_user", userEntity.Username)
|
||||||
|
|
||||||
|
identity, err := client.AuthIdentity.Query().
|
||||||
|
Where(
|
||||||
|
authidentity.ProviderTypeEQ("oidc"),
|
||||||
|
authidentity.ProviderKeyEQ("https://issuer.example.com"),
|
||||||
|
authidentity.ProviderSubjectEQ("oidc-subject-no-adoption"),
|
||||||
|
).
|
||||||
|
Only(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, userEntity.ID, identity.UserID)
|
||||||
|
|
||||||
|
decision, err := client.IdentityAdoptionDecision.Query().
|
||||||
|
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
|
||||||
|
Only(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, decision.IdentityID)
|
||||||
|
require.Equal(t, identity.ID, *decision.IdentityID)
|
||||||
|
require.False(t, decision.AdoptDisplayName)
|
||||||
|
require.False(t, decision.AdoptAvatar)
|
||||||
|
}
|
||||||
|
|
||||||
type oidcProviderFixture struct {
|
type oidcProviderFixture struct {
|
||||||
Subject string
|
Subject string
|
||||||
PreferredUsername string
|
PreferredUsername string
|
||||||
|
|||||||
@@ -551,7 +551,7 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) {
|
|||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
|
decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
|
||||||
AdoptDisplayName: req.AdoptDisplayName,
|
AdoptDisplayName: req.AdoptDisplayName,
|
||||||
AdoptAvatar: req.AdoptAvatar,
|
AdoptAvatar: req.AdoptAvatar,
|
||||||
})
|
})
|
||||||
@@ -827,7 +827,10 @@ func (h *AuthHandler) findWeChatUserByLegacyOpenID(
|
|||||||
return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
|
return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
|
||||||
}
|
}
|
||||||
if user, err := singleWeChatIdentityUser(records); err != nil || user != nil {
|
if user, err := singleWeChatIdentityUser(records); err != nil || user != nil {
|
||||||
return user, err
|
if err != nil || user == nil {
|
||||||
|
return user, err
|
||||||
|
}
|
||||||
|
return findActiveUserByID(ctx, client, user.ID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -851,7 +854,10 @@ func (h *AuthHandler) findWeChatUserByLegacyOpenID(
|
|||||||
return nil, infraerrors.InternalServer("AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED", "failed to inspect auth identity channel ownership").WithCause(err)
|
return nil, infraerrors.InternalServer("AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED", "failed to inspect auth identity channel ownership").WithCause(err)
|
||||||
}
|
}
|
||||||
if user, err := singleWeChatChannelUser(records); err != nil || user != nil {
|
if user, err := singleWeChatChannelUser(records); err != nil || user != nil {
|
||||||
return user, err
|
if err != nil || user == nil {
|
||||||
|
return user, err
|
||||||
|
}
|
||||||
|
return findActiveUserByID(ctx, client, user.ID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -870,7 +876,11 @@ func (h *AuthHandler) findWeChatUserByLegacyOpenID(
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
|
return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
|
||||||
}
|
}
|
||||||
return singleWeChatIdentityUser(records)
|
user, err := singleWeChatIdentityUser(records)
|
||||||
|
if err != nil || user == nil {
|
||||||
|
return user, err
|
||||||
|
}
|
||||||
|
return findActiveUserByID(ctx, client, user.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func wechatCompatibleProviderKeys(providerKey string) []string {
|
func wechatCompatibleProviderKeys(providerKey string) []string {
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/ent/enttest"
|
"github.com/Wei-Shaw/sub2api/ent/enttest"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
|
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
|
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
|
||||||
|
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/payment"
|
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/repository"
|
"github.com/Wei-Shaw/sub2api/internal/repository"
|
||||||
@@ -292,6 +293,71 @@ func TestWeChatOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUserWit
|
|||||||
require.False(t, hasRefreshToken)
|
require.False(t, hasRefreshToken)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWeChatOAuthCallbackRejectsDisabledExistingIdentityUser(t *testing.T) {
|
||||||
|
originalAccessTokenURL := wechatOAuthAccessTokenURL
|
||||||
|
originalUserInfoURL := wechatOAuthUserInfoURL
|
||||||
|
t.Cleanup(func() {
|
||||||
|
wechatOAuthAccessTokenURL = originalAccessTokenURL
|
||||||
|
wechatOAuthUserInfoURL = originalUserInfoURL
|
||||||
|
})
|
||||||
|
|
||||||
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch {
|
||||||
|
case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-disabled","unionid":"union-disabled","scope":"snsapi_login"}`))
|
||||||
|
case strings.Contains(r.URL.Path, "/sns/userinfo"):
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write([]byte(`{"openid":"openid-disabled","unionid":"union-disabled","nickname":"Disabled WeChat","headimgurl":"https://cdn.example/disabled.png"}`))
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer upstream.Close()
|
||||||
|
wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
|
||||||
|
wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
|
||||||
|
|
||||||
|
handler, client := newWeChatOAuthTestHandler(t, false)
|
||||||
|
defer client.Close()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
existingUser, err := client.User.Create().
|
||||||
|
SetEmail(wechatSyntheticEmail("union-disabled")).
|
||||||
|
SetUsername("disabled-user").
|
||||||
|
SetPasswordHash("hash").
|
||||||
|
SetRole(service.RoleUser).
|
||||||
|
SetStatus(service.StatusDisabled).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = client.AuthIdentity.Create().
|
||||||
|
SetUserID(existingUser.ID).
|
||||||
|
SetProviderType("wechat").
|
||||||
|
SetProviderKey(wechatOAuthProviderKey).
|
||||||
|
SetProviderSubject("union-disabled").
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-disabled", nil)
|
||||||
|
req.Host = "api.example.com"
|
||||||
|
req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-disabled"))
|
||||||
|
req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
|
||||||
|
req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
|
||||||
|
req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-disabled"))
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
handler.WeChatOAuthCallback(c)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusFound, recorder.Code)
|
||||||
|
require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
|
||||||
|
assertOAuthRedirectError(t, recorder.Header().Get("Location"), "session_error", "USER_NOT_ACTIVE")
|
||||||
|
|
||||||
|
count, err := client.PendingAuthSession.Query().Count(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Zero(t, count)
|
||||||
|
}
|
||||||
|
|
||||||
func TestWeChatPaymentOAuthCallbackRedirectsWithOpaqueResumeToken(t *testing.T) {
|
func TestWeChatPaymentOAuthCallbackRedirectsWithOpaqueResumeToken(t *testing.T) {
|
||||||
originalAccessTokenURL := wechatOAuthAccessTokenURL
|
originalAccessTokenURL := wechatOAuthAccessTokenURL
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
@@ -816,6 +882,73 @@ func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSessionReturnsPend
|
|||||||
require.Zero(t, decisionCount)
|
require.Zero(t, decisionCount)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCompleteWeChatOAuthRegistrationBindsIdentityWithoutAdoptionFlags(t *testing.T) {
|
||||||
|
handler, client := newOAuthPendingFlowTestHandler(t, false)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
session, err := client.PendingAuthSession.Create().
|
||||||
|
SetSessionToken("wechat-complete-no-adoption-session").
|
||||||
|
SetIntent("login").
|
||||||
|
SetProviderType("wechat").
|
||||||
|
SetProviderKey(wechatOAuthProviderKey).
|
||||||
|
SetProviderSubject("wechat-subject-no-adoption").
|
||||||
|
SetResolvedEmail("wechat-subject-no-adoption@wechat-connect.invalid").
|
||||||
|
SetBrowserSessionKey("wechat-browser-no-adoption").
|
||||||
|
SetUpstreamIdentityClaims(map[string]any{
|
||||||
|
"username": "wechat_user",
|
||||||
|
"suggested_display_name": "WeChat Legacy",
|
||||||
|
"suggested_avatar_url": "https://cdn.example/wechat-legacy.png",
|
||||||
|
"mode": "open",
|
||||||
|
"channel": "open",
|
||||||
|
"channel_app_id": "wx-open-app",
|
||||||
|
"channel_subject": "openid-legacy",
|
||||||
|
}).
|
||||||
|
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
completeCtx, _ := gin.CreateTestContext(recorder)
|
||||||
|
completeReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/wechat/complete-registration", body)
|
||||||
|
completeReq.Header.Set("Content-Type", "application/json")
|
||||||
|
completeReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
|
||||||
|
completeReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("wechat-browser-no-adoption")})
|
||||||
|
completeCtx.Request = completeReq
|
||||||
|
|
||||||
|
handler.CompleteWeChatOAuthRegistration(completeCtx)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
responseData := decodeJSONBody(t, recorder)
|
||||||
|
require.NotEmpty(t, responseData["access_token"])
|
||||||
|
require.NotEmpty(t, responseData["refresh_token"])
|
||||||
|
|
||||||
|
userEntity, err := client.User.Query().
|
||||||
|
Where(dbuser.EmailEQ(session.ResolvedEmail)).
|
||||||
|
Only(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "wechat_user", userEntity.Username)
|
||||||
|
|
||||||
|
identity, err := client.AuthIdentity.Query().
|
||||||
|
Where(
|
||||||
|
authidentity.ProviderTypeEQ("wechat"),
|
||||||
|
authidentity.ProviderKeyEQ(wechatOAuthProviderKey),
|
||||||
|
authidentity.ProviderSubjectEQ("wechat-subject-no-adoption"),
|
||||||
|
).
|
||||||
|
Only(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, userEntity.ID, identity.UserID)
|
||||||
|
|
||||||
|
decision, err := client.IdentityAdoptionDecision.Query().
|
||||||
|
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
|
||||||
|
Only(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, decision.IdentityID)
|
||||||
|
require.Equal(t, identity.ID, *decision.IdentityID)
|
||||||
|
require.False(t, decision.AdoptDisplayName)
|
||||||
|
require.False(t, decision.AdoptAvatar)
|
||||||
|
}
|
||||||
|
|
||||||
func TestWeChatOAuthCallbackRepairsLegacyOpenIDOnlyIdentity(t *testing.T) {
|
func TestWeChatOAuthCallbackRepairsLegacyOpenIDOnlyIdentity(t *testing.T) {
|
||||||
originalAccessTokenURL := wechatOAuthAccessTokenURL
|
originalAccessTokenURL := wechatOAuthAccessTokenURL
|
||||||
originalUserInfoURL := wechatOAuthUserInfoURL
|
originalUserInfoURL := wechatOAuthUserInfoURL
|
||||||
|
|||||||
Reference in New Issue
Block a user