fix(auth): preserve backward-compatible oauth defaults
This commit is contained in:
@@ -653,20 +653,22 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
req.WeChatConnectScopes = service.DefaultWeChatConnectScopesForMode(req.WeChatConnectMode)
|
||||
}
|
||||
}
|
||||
if req.WeChatConnectRedirectURL == "" {
|
||||
response.BadRequest(c, "WeChat Redirect URL is required when enabled")
|
||||
return
|
||||
}
|
||||
if err := config.ValidateAbsoluteHTTPURL(req.WeChatConnectRedirectURL); err != nil {
|
||||
response.BadRequest(c, "WeChat Redirect URL must be an absolute http(s) URL")
|
||||
return
|
||||
}
|
||||
if req.WeChatConnectFrontendRedirectURL == "" {
|
||||
req.WeChatConnectFrontendRedirectURL = "/auth/wechat/callback"
|
||||
}
|
||||
if err := config.ValidateFrontendRedirectURL(req.WeChatConnectFrontendRedirectURL); err != nil {
|
||||
response.BadRequest(c, "WeChat Frontend Redirect URL is invalid")
|
||||
return
|
||||
if req.WeChatConnectOpenEnabled || req.WeChatConnectMPEnabled {
|
||||
if req.WeChatConnectRedirectURL == "" {
|
||||
response.BadRequest(c, "WeChat Redirect URL is required when web oauth is enabled")
|
||||
return
|
||||
}
|
||||
if err := config.ValidateAbsoluteHTTPURL(req.WeChatConnectRedirectURL); err != nil {
|
||||
response.BadRequest(c, "WeChat Redirect URL must be an absolute http(s) URL")
|
||||
return
|
||||
}
|
||||
if req.WeChatConnectFrontendRedirectURL == "" {
|
||||
req.WeChatConnectFrontendRedirectURL = "/auth/wechat/callback"
|
||||
}
|
||||
if err := config.ValidateFrontendRedirectURL(req.WeChatConnectFrontendRedirectURL); err != nil {
|
||||
response.BadRequest(c, "WeChat Frontend Redirect URL is invalid")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -749,14 +751,6 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
response.BadRequest(c, "OIDC scopes must contain openid")
|
||||
return
|
||||
}
|
||||
if !req.OIDCConnectUsePKCE {
|
||||
response.BadRequest(c, "OIDC PKCE must be enabled")
|
||||
return
|
||||
}
|
||||
if !req.OIDCConnectValidateIDToken {
|
||||
response.BadRequest(c, "OIDC ID Token validation must be enabled")
|
||||
return
|
||||
}
|
||||
switch req.OIDCConnectTokenAuthMethod {
|
||||
case "", "client_secret_post", "client_secret_basic", "none":
|
||||
default:
|
||||
@@ -767,7 +761,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
response.BadRequest(c, "OIDC clock skew seconds must be between 0 and 600")
|
||||
return
|
||||
}
|
||||
if req.OIDCConnectAllowedSigningAlgs == "" {
|
||||
if req.OIDCConnectValidateIDToken && req.OIDCConnectAllowedSigningAlgs == "" {
|
||||
response.BadRequest(c, "OIDC Allowed Signing Algs is required when validate_id_token=true")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -123,13 +123,16 @@ func (h *AuthHandler) LinuxDoOAuthStart(c *gin.Context) {
|
||||
clearCookie(c, linuxDoOAuthBindUserCookieName, secureCookie)
|
||||
}
|
||||
|
||||
verifier, err := oauth.GenerateCodeVerifier()
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(err))
|
||||
return
|
||||
codeChallenge := ""
|
||||
if cfg.UsePKCE {
|
||||
verifier, err := oauth.GenerateCodeVerifier()
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(err))
|
||||
return
|
||||
}
|
||||
codeChallenge = oauth.GenerateCodeChallenge(verifier)
|
||||
setCookie(c, linuxDoOAuthVerifierCookie, encodeCookieValue(verifier), linuxDoOAuthCookieMaxAgeSec, secureCookie)
|
||||
}
|
||||
codeChallenge := oauth.GenerateCodeChallenge(verifier)
|
||||
setCookie(c, linuxDoOAuthVerifierCookie, encodeCookieValue(verifier), linuxDoOAuthCookieMaxAgeSec, secureCookie)
|
||||
|
||||
redirectURI := strings.TrimSpace(cfg.RedirectURL)
|
||||
if redirectURI == "" {
|
||||
@@ -200,10 +203,13 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
|
||||
intent, _ := readCookieDecoded(c, linuxDoOAuthIntentCookieName)
|
||||
intent = normalizeOAuthIntent(intent)
|
||||
|
||||
codeVerifier, _ := readCookieDecoded(c, linuxDoOAuthVerifierCookie)
|
||||
if codeVerifier == "" {
|
||||
redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "")
|
||||
return
|
||||
codeVerifier := ""
|
||||
if cfg.UsePKCE {
|
||||
codeVerifier, _ = readCookieDecoded(c, linuxDoOAuthVerifierCookie)
|
||||
if codeVerifier == "" {
|
||||
redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
redirectURI := strings.TrimSpace(cfg.RedirectURL)
|
||||
@@ -292,25 +298,16 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
if existingIdentityUser != nil {
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), existingIdentityUser.Email, username, "")
|
||||
if err != nil {
|
||||
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
|
||||
return
|
||||
}
|
||||
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
|
||||
Intent: oauthIntentLogin,
|
||||
Identity: identityKey,
|
||||
TargetUserID: &user.ID,
|
||||
TargetUserID: &existingIdentityUser.ID,
|
||||
ResolvedEmail: existingIdentityUser.Email,
|
||||
RedirectTo: redirectTo,
|
||||
BrowserSessionKey: browserSessionKey,
|
||||
UpstreamIdentityClaims: upstreamClaims,
|
||||
CompletionResponse: map[string]any{
|
||||
"access_token": tokenPair.AccessToken,
|
||||
"refresh_token": tokenPair.RefreshToken,
|
||||
"expires_in": tokenPair.ExpiresIn,
|
||||
"token_type": "Bearer",
|
||||
"redirect": redirectTo,
|
||||
"redirect": redirectTo,
|
||||
},
|
||||
}); err != nil {
|
||||
redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
|
||||
@@ -546,7 +543,9 @@ func linuxDoExchangeCode(
|
||||
form.Set("client_id", cfg.ClientID)
|
||||
form.Set("code", code)
|
||||
form.Set("redirect_uri", redirectURI)
|
||||
form.Set("code_verifier", codeVerifier)
|
||||
if strings.TrimSpace(codeVerifier) != "" {
|
||||
form.Set("code_verifier", codeVerifier)
|
||||
}
|
||||
|
||||
r := client.R().
|
||||
SetContext(ctx).
|
||||
@@ -699,8 +698,10 @@ func buildLinuxDoAuthorizeURL(cfg config.LinuxDoConnectConfig, state string, cod
|
||||
q.Set("scope", cfg.Scopes)
|
||||
}
|
||||
q.Set("state", state)
|
||||
q.Set("code_challenge", codeChallenge)
|
||||
q.Set("code_challenge_method", "S256")
|
||||
if strings.TrimSpace(codeChallenge) != "" {
|
||||
q.Set("code_challenge", codeChallenge)
|
||||
q.Set("code_challenge_method", "S256")
|
||||
}
|
||||
|
||||
u.RawQuery = q.Encode()
|
||||
return u.String(), nil
|
||||
|
||||
@@ -171,6 +171,80 @@ func TestLinuxDoOAuthBindStartRedirectsAndSetsBindCookies(t *testing.T) {
|
||||
require.Equal(t, int64(42), userID)
|
||||
}
|
||||
|
||||
func TestLinuxDoOAuthStartOmitsPKCEWhenDisabled(t *testing.T) {
|
||||
handler := newLinuxDoOAuthTestHandler(t, false, config.LinuxDoConnectConfig{
|
||||
Enabled: true,
|
||||
ClientID: "linuxdo-client",
|
||||
ClientSecret: "linuxdo-secret",
|
||||
AuthorizeURL: "https://connect.linux.do/oauth/authorize",
|
||||
TokenURL: "https://connect.linux.do/oauth/token",
|
||||
UserInfoURL: "https://connect.linux.do/api/user",
|
||||
Scopes: "read",
|
||||
RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback",
|
||||
FrontendRedirectURL: "/auth/linuxdo/callback",
|
||||
TokenAuthMethod: "client_secret_post",
|
||||
UsePKCE: false,
|
||||
})
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/start?redirect=/dashboard", nil)
|
||||
|
||||
handler.LinuxDoOAuthStart(c)
|
||||
|
||||
require.Equal(t, http.StatusFound, recorder.Code)
|
||||
require.NotContains(t, recorder.Header().Get("Location"), "code_challenge=")
|
||||
require.Nil(t, findCookie(recorder.Result().Cookies(), linuxDoOAuthVerifierCookie))
|
||||
}
|
||||
|
||||
func TestLinuxDoOAuthCallbackAllowsMissingVerifierWhenPKCEDisabled(t *testing.T) {
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/token":
|
||||
require.NoError(t, r.ParseForm())
|
||||
require.Empty(t, r.PostForm.Get("code_verifier"))
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`))
|
||||
case "/userinfo":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"id":"compat-subject","username":"linuxdo_user","name":"LinuxDo Display"}`))
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{
|
||||
Enabled: true,
|
||||
ClientID: "linuxdo-client",
|
||||
ClientSecret: "linuxdo-secret",
|
||||
AuthorizeURL: upstream.URL + "/authorize",
|
||||
TokenURL: upstream.URL + "/token",
|
||||
UserInfoURL: upstream.URL + "/userinfo",
|
||||
Scopes: "read",
|
||||
RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback",
|
||||
FrontendRedirectURL: "/auth/linuxdo/callback",
|
||||
TokenAuthMethod: "client_secret_post",
|
||||
UsePKCE: false,
|
||||
})
|
||||
t.Cleanup(func() { _ = client.Close() })
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=linuxdo-code&state=state-123", nil)
|
||||
req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-123"))
|
||||
req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/dashboard"))
|
||||
req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentLogin))
|
||||
req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
|
||||
c.Request = req
|
||||
|
||||
handler.LinuxDoOAuthCallback(c)
|
||||
|
||||
require.Equal(t, http.StatusFound, recorder.Code)
|
||||
require.Equal(t, "/auth/linuxdo/callback", recorder.Header().Get("Location"))
|
||||
require.NotNil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
|
||||
}
|
||||
|
||||
func TestLinuxDoOAuthBindStartAcceptsAccessTokenCookie(t *testing.T) {
|
||||
handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{
|
||||
Enabled: true,
|
||||
@@ -327,7 +401,10 @@ func TestLinuxDoOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUser(t
|
||||
completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "/dashboard", completion["redirect"])
|
||||
require.NotEmpty(t, completion["access_token"])
|
||||
_, hasAccessToken := completion["access_token"]
|
||||
require.False(t, hasAccessToken)
|
||||
_, hasRefreshToken := completion["refresh_token"]
|
||||
require.False(t, hasRefreshToken)
|
||||
require.Nil(t, completion["error"])
|
||||
}
|
||||
|
||||
|
||||
@@ -157,21 +157,25 @@ func (h *AuthHandler) OIDCOAuthStart(c *gin.Context) {
|
||||
}
|
||||
|
||||
codeChallenge := ""
|
||||
verifier, genErr := oauth.GenerateCodeVerifier()
|
||||
if genErr != nil {
|
||||
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(genErr))
|
||||
return
|
||||
if cfg.UsePKCE {
|
||||
verifier, genErr := oauth.GenerateCodeVerifier()
|
||||
if genErr != nil {
|
||||
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(genErr))
|
||||
return
|
||||
}
|
||||
codeChallenge = oauth.GenerateCodeChallenge(verifier)
|
||||
oidcSetCookie(c, oidcOAuthVerifierCookie, encodeCookieValue(verifier), oidcOAuthCookieMaxAgeSec, secureCookie)
|
||||
}
|
||||
codeChallenge = oauth.GenerateCodeChallenge(verifier)
|
||||
oidcSetCookie(c, oidcOAuthVerifierCookie, encodeCookieValue(verifier), oidcOAuthCookieMaxAgeSec, secureCookie)
|
||||
|
||||
nonce := ""
|
||||
nonce, err = oauth.GenerateState()
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_NONCE_GEN_FAILED", "failed to generate oauth nonce").WithCause(err))
|
||||
return
|
||||
if cfg.ValidateIDToken {
|
||||
nonce, err = oauth.GenerateState()
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_NONCE_GEN_FAILED", "failed to generate oauth nonce").WithCause(err))
|
||||
return
|
||||
}
|
||||
oidcSetCookie(c, oidcOAuthNonceCookie, encodeCookieValue(nonce), oidcOAuthCookieMaxAgeSec, secureCookie)
|
||||
}
|
||||
oidcSetCookie(c, oidcOAuthNonceCookie, encodeCookieValue(nonce), oidcOAuthCookieMaxAgeSec, secureCookie)
|
||||
|
||||
redirectURI := strings.TrimSpace(cfg.RedirectURL)
|
||||
if redirectURI == "" {
|
||||
@@ -244,17 +248,21 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
|
||||
intent = normalizeOAuthIntent(intent)
|
||||
|
||||
codeVerifier := ""
|
||||
codeVerifier, _ = readCookieDecoded(c, oidcOAuthVerifierCookie)
|
||||
if codeVerifier == "" {
|
||||
redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "")
|
||||
return
|
||||
if cfg.UsePKCE {
|
||||
codeVerifier, _ = readCookieDecoded(c, oidcOAuthVerifierCookie)
|
||||
if codeVerifier == "" {
|
||||
redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
expectedNonce := ""
|
||||
expectedNonce, _ = readCookieDecoded(c, oidcOAuthNonceCookie)
|
||||
if expectedNonce == "" {
|
||||
redirectOAuthError(c, frontendCallback, "missing_nonce", "missing oauth nonce", "")
|
||||
return
|
||||
if cfg.ValidateIDToken {
|
||||
expectedNonce, _ = readCookieDecoded(c, oidcOAuthNonceCookie)
|
||||
if expectedNonce == "" {
|
||||
redirectOAuthError(c, frontendCallback, "missing_nonce", "missing oauth nonce", "")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
redirectURI := strings.TrimSpace(cfg.RedirectURL)
|
||||
@@ -284,16 +292,19 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if strings.TrimSpace(tokenResp.IDToken) == "" {
|
||||
redirectOAuthError(c, frontendCallback, "missing_id_token", "missing id_token", "")
|
||||
return
|
||||
}
|
||||
var idClaims *oidcIDTokenClaims
|
||||
if cfg.ValidateIDToken {
|
||||
if strings.TrimSpace(tokenResp.IDToken) == "" {
|
||||
redirectOAuthError(c, frontendCallback, "missing_id_token", "missing id_token", "")
|
||||
return
|
||||
}
|
||||
|
||||
idClaims, err := oidcParseAndValidateIDToken(c.Request.Context(), cfg, tokenResp.IDToken, expectedNonce)
|
||||
if err != nil {
|
||||
log.Printf("[OIDC OAuth] id_token validation failed: %v", err)
|
||||
redirectOAuthError(c, frontendCallback, "invalid_id_token", "failed to validate id_token", "")
|
||||
return
|
||||
idClaims, err = oidcParseAndValidateIDToken(c.Request.Context(), cfg, tokenResp.IDToken, expectedNonce)
|
||||
if err != nil {
|
||||
log.Printf("[OIDC OAuth] id_token validation failed: %v", err)
|
||||
redirectOAuthError(c, frontendCallback, "invalid_id_token", "failed to validate id_token", "")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
userInfoClaims, err := oidcFetchUserInfo(c.Request.Context(), cfg, tokenResp)
|
||||
@@ -303,7 +314,10 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
subject := strings.TrimSpace(idClaims.Subject)
|
||||
subject := ""
|
||||
if idClaims != nil {
|
||||
subject = strings.TrimSpace(idClaims.Subject)
|
||||
}
|
||||
if subject == "" {
|
||||
subject = strings.TrimSpace(userInfoClaims.Subject)
|
||||
}
|
||||
@@ -311,7 +325,10 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
|
||||
redirectOAuthError(c, frontendCallback, "missing_subject", "missing subject claim", "")
|
||||
return
|
||||
}
|
||||
issuer := strings.TrimSpace(idClaims.Issuer)
|
||||
issuer := ""
|
||||
if idClaims != nil {
|
||||
issuer = strings.TrimSpace(idClaims.Issuer)
|
||||
}
|
||||
if issuer == "" {
|
||||
issuer = strings.TrimSpace(cfg.IssuerURL)
|
||||
}
|
||||
@@ -321,21 +338,34 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
|
||||
}
|
||||
|
||||
emailVerified := userInfoClaims.EmailVerified
|
||||
if emailVerified == nil {
|
||||
if emailVerified == nil && idClaims != nil {
|
||||
emailVerified = idClaims.EmailVerified
|
||||
}
|
||||
if userInfoClaims.Subject != "" && idClaims.Subject != "" && strings.TrimSpace(userInfoClaims.Subject) != strings.TrimSpace(idClaims.Subject) {
|
||||
if idClaims != nil && userInfoClaims.Subject != "" && idClaims.Subject != "" && strings.TrimSpace(userInfoClaims.Subject) != strings.TrimSpace(idClaims.Subject) {
|
||||
redirectOAuthError(c, frontendCallback, "subject_mismatch", "userinfo subject does not match id_token", "")
|
||||
return
|
||||
}
|
||||
|
||||
identityKey := oidcIdentityKey(issuer, subject)
|
||||
compatEmail := strings.TrimSpace(firstNonEmpty(userInfoClaims.Email, idClaims.Email))
|
||||
compatEmail := strings.TrimSpace(userInfoClaims.Email)
|
||||
if compatEmail == "" && idClaims != nil {
|
||||
compatEmail = strings.TrimSpace(idClaims.Email)
|
||||
}
|
||||
email := oidcSyntheticEmailFromIdentityKey(identityKey)
|
||||
username := firstNonEmpty(
|
||||
userInfoClaims.Username,
|
||||
idClaims.PreferredUsername,
|
||||
idClaims.Name,
|
||||
func() string {
|
||||
if idClaims != nil {
|
||||
return idClaims.PreferredUsername
|
||||
}
|
||||
return ""
|
||||
}(),
|
||||
func() string {
|
||||
if idClaims != nil {
|
||||
return idClaims.Name
|
||||
}
|
||||
return ""
|
||||
}(),
|
||||
oidcFallbackUsername(subject),
|
||||
)
|
||||
identityRef := service.PendingAuthIdentityKey{
|
||||
@@ -350,7 +380,12 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
|
||||
"issuer": issuer,
|
||||
"email_verified": emailVerified != nil && *emailVerified,
|
||||
"provider_fallback": strings.TrimSpace(cfg.ProviderName),
|
||||
"suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username),
|
||||
"suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, func() string {
|
||||
if idClaims != nil {
|
||||
return idClaims.Name
|
||||
}
|
||||
return ""
|
||||
}(), username),
|
||||
"suggested_avatar_url": userInfoClaims.AvatarURL,
|
||||
}
|
||||
if compatEmail != "" && !strings.EqualFold(strings.TrimSpace(compatEmail), strings.TrimSpace(email)) {
|
||||
@@ -387,25 +422,16 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
if existingIdentityUser != nil {
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), existingIdentityUser.Email, username, "")
|
||||
if err != nil {
|
||||
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
|
||||
return
|
||||
}
|
||||
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
|
||||
Intent: oauthIntentLogin,
|
||||
Identity: identityRef,
|
||||
TargetUserID: &user.ID,
|
||||
TargetUserID: &existingIdentityUser.ID,
|
||||
ResolvedEmail: existingIdentityUser.Email,
|
||||
RedirectTo: redirectTo,
|
||||
BrowserSessionKey: browserSessionKey,
|
||||
UpstreamIdentityClaims: upstreamClaims,
|
||||
CompletionResponse: map[string]any{
|
||||
"access_token": tokenPair.AccessToken,
|
||||
"refresh_token": tokenPair.RefreshToken,
|
||||
"expires_in": tokenPair.ExpiresIn,
|
||||
"token_type": "Bearer",
|
||||
"redirect": redirectTo,
|
||||
"redirect": redirectTo,
|
||||
},
|
||||
}); err != nil {
|
||||
redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
|
||||
@@ -670,7 +696,9 @@ func oidcExchangeCode(
|
||||
form.Set("client_id", cfg.ClientID)
|
||||
form.Set("code", code)
|
||||
form.Set("redirect_uri", redirectURI)
|
||||
form.Set("code_verifier", codeVerifier)
|
||||
if strings.TrimSpace(codeVerifier) != "" {
|
||||
form.Set("code_verifier", codeVerifier)
|
||||
}
|
||||
|
||||
r := client.R().
|
||||
SetContext(ctx).
|
||||
@@ -872,9 +900,13 @@ func buildOIDCAuthorizeURL(cfg config.OIDCConnectConfig, state, nonce, codeChall
|
||||
q.Set("scope", cfg.Scopes)
|
||||
}
|
||||
q.Set("state", state)
|
||||
q.Set("nonce", nonce)
|
||||
q.Set("code_challenge", codeChallenge)
|
||||
q.Set("code_challenge_method", "S256")
|
||||
if strings.TrimSpace(nonce) != "" {
|
||||
q.Set("nonce", nonce)
|
||||
}
|
||||
if strings.TrimSpace(codeChallenge) != "" {
|
||||
q.Set("code_challenge", codeChallenge)
|
||||
q.Set("code_challenge_method", "S256")
|
||||
}
|
||||
|
||||
u.RawQuery = q.Encode()
|
||||
return u.String(), nil
|
||||
|
||||
@@ -186,6 +186,89 @@ func TestOIDCOAuthBindStartRedirectsAndSetsBindCookies(t *testing.T) {
|
||||
require.Equal(t, int64(84), userID)
|
||||
}
|
||||
|
||||
func TestOIDCOAuthStartOmitsPKCEAndNonceWhenDisabled(t *testing.T) {
|
||||
handler := newOIDCOAuthTestHandler(t, false, config.OIDCConnectConfig{
|
||||
Enabled: true,
|
||||
ClientID: "oidc-client",
|
||||
ClientSecret: "oidc-secret",
|
||||
IssuerURL: "https://issuer.example.com",
|
||||
AuthorizeURL: "https://issuer.example.com/oauth/authorize",
|
||||
TokenURL: "https://issuer.example.com/oauth/token",
|
||||
UserInfoURL: "https://issuer.example.com/oauth/userinfo",
|
||||
Scopes: "openid profile email",
|
||||
RedirectURL: "https://api.example.com/api/v1/auth/oauth/oidc/callback",
|
||||
FrontendRedirectURL: "/auth/oidc/callback",
|
||||
TokenAuthMethod: "client_secret_post",
|
||||
UsePKCE: false,
|
||||
ValidateIDToken: false,
|
||||
RequireEmailVerified: false,
|
||||
})
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/start?redirect=/dashboard", nil)
|
||||
|
||||
handler.OIDCOAuthStart(c)
|
||||
|
||||
require.Equal(t, http.StatusFound, recorder.Code)
|
||||
location := recorder.Header().Get("Location")
|
||||
require.NotContains(t, location, "code_challenge=")
|
||||
require.NotContains(t, location, "nonce=")
|
||||
require.Nil(t, findCookie(recorder.Result().Cookies(), oidcOAuthVerifierCookie))
|
||||
require.Nil(t, findCookie(recorder.Result().Cookies(), oidcOAuthNonceCookie))
|
||||
}
|
||||
|
||||
func TestOIDCOAuthCallbackAllowsOptionalPKCEAndIDTokenValidation(t *testing.T) {
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/token":
|
||||
require.NoError(t, r.ParseForm())
|
||||
require.Empty(t, r.PostForm.Get("code_verifier"))
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"access_token":"oidc-access","token_type":"Bearer","expires_in":3600}`))
|
||||
case "/userinfo":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"sub":"oidc-subject-compat","preferred_username":"oidc_user","name":"OIDC Display","email":"oidc@example.com"}`))
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
handler, client := newOIDCOAuthHandlerAndClient(t, false, config.OIDCConnectConfig{
|
||||
Enabled: true,
|
||||
ClientID: "oidc-client",
|
||||
ClientSecret: "oidc-secret",
|
||||
IssuerURL: "https://issuer.example.com",
|
||||
AuthorizeURL: upstream.URL + "/authorize",
|
||||
TokenURL: upstream.URL + "/token",
|
||||
UserInfoURL: upstream.URL + "/userinfo",
|
||||
Scopes: "openid profile email",
|
||||
RedirectURL: "https://api.example.com/api/v1/auth/oauth/oidc/callback",
|
||||
FrontendRedirectURL: "/auth/oidc/callback",
|
||||
TokenAuthMethod: "client_secret_post",
|
||||
UsePKCE: false,
|
||||
ValidateIDToken: false,
|
||||
RequireEmailVerified: false,
|
||||
})
|
||||
t.Cleanup(func() { _ = client.Close() })
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-123", nil)
|
||||
req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-123"))
|
||||
req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/dashboard"))
|
||||
req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin))
|
||||
req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
|
||||
c.Request = req
|
||||
|
||||
handler.OIDCOAuthCallback(c)
|
||||
|
||||
require.Equal(t, http.StatusFound, recorder.Code)
|
||||
require.Equal(t, "/auth/oidc/callback", recorder.Header().Get("Location"))
|
||||
require.NotNil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
|
||||
}
|
||||
|
||||
func TestOIDCOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUser(t *testing.T) {
|
||||
cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{
|
||||
Subject: "oidc-subject-login",
|
||||
@@ -250,7 +333,10 @@ func TestOIDCOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUser(t *t
|
||||
completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "/dashboard", completion["redirect"])
|
||||
require.NotEmpty(t, completion["access_token"])
|
||||
_, hasAccessToken := completion["access_token"]
|
||||
require.False(t, hasAccessToken)
|
||||
_, hasRefreshToken := completion["refresh_token"]
|
||||
require.False(t, hasRefreshToken)
|
||||
require.Nil(t, completion["error"])
|
||||
}
|
||||
|
||||
|
||||
@@ -279,12 +279,7 @@ func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) {
|
||||
redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
|
||||
return
|
||||
}
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), existingIdentityUser.Email, username, "")
|
||||
if err != nil {
|
||||
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
|
||||
return
|
||||
}
|
||||
if err := h.createWeChatPendingSession(c, normalizedIntent, providerSubject, existingIdentityUser.Email, redirectTo, browserSessionKey, upstreamClaims, tokenPair, nil, &user.ID); err != nil {
|
||||
if err := h.createWeChatPendingSession(c, normalizedIntent, providerSubject, existingIdentityUser.Email, redirectTo, browserSessionKey, upstreamClaims, nil, nil, &existingIdentityUser.ID); err != nil {
|
||||
redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -213,6 +213,86 @@ func TestWeChatOAuthCallbackFallsBackToOpenIDWhenUnionIDMissingInSingleChannelMo
|
||||
require.Equal(t, "third_party_signup", completion["choice_reason"])
|
||||
}
|
||||
|
||||
func TestWeChatOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUserWithoutStoredTokens(t *testing.T) {
|
||||
originalAccessTokenURL := wechatOAuthAccessTokenURL
|
||||
originalUserInfoURL := wechatOAuthUserInfoURL
|
||||
t.Cleanup(func() {
|
||||
wechatOAuthAccessTokenURL = originalAccessTokenURL
|
||||
wechatOAuthUserInfoURL = originalUserInfoURL
|
||||
})
|
||||
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
|
||||
case strings.Contains(r.URL.Path, "/sns/userinfo"):
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"WeChat Display","headimgurl":"https://cdn.example/wechat-login.png"}`))
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer upstream.Close()
|
||||
wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
|
||||
wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
|
||||
|
||||
handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings("open", "wx-open-app", "wx-open-secret", "https://app.example.com/auth/wechat/callback"))
|
||||
defer client.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
existingUser, err := client.User.Create().
|
||||
SetEmail(wechatSyntheticEmail("union-456")).
|
||||
SetUsername("wechat-existing-user").
|
||||
SetPasswordHash("hash").
|
||||
SetRole(service.RoleUser).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
_, err = client.AuthIdentity.Create().
|
||||
SetUserID(existingUser.ID).
|
||||
SetProviderType("wechat").
|
||||
SetProviderKey(wechatOAuthProviderKey).
|
||||
SetProviderSubject("union-456").
|
||||
SetMetadata(map[string]any{"username": "wechat-existing-user"}).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
|
||||
req.Host = "api.example.com"
|
||||
req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
|
||||
req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
|
||||
req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
|
||||
req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
|
||||
c.Request = req
|
||||
|
||||
handler.WeChatOAuthCallback(c)
|
||||
|
||||
require.Equal(t, http.StatusFound, recorder.Code)
|
||||
require.Equal(t, "https://app.example.com/auth/wechat/callback", recorder.Header().Get("Location"))
|
||||
|
||||
sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
|
||||
require.NotNil(t, sessionCookie)
|
||||
|
||||
session, err := client.PendingAuthSession.Query().
|
||||
Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
|
||||
Only(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, oauthIntentLogin, session.Intent)
|
||||
require.NotNil(t, session.TargetUserID)
|
||||
require.Equal(t, existingUser.ID, *session.TargetUserID)
|
||||
require.Equal(t, existingUser.Email, session.ResolvedEmail)
|
||||
|
||||
completion := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
|
||||
require.Equal(t, "/dashboard", completion["redirect"])
|
||||
_, hasAccessToken := completion["access_token"]
|
||||
require.False(t, hasAccessToken)
|
||||
_, hasRefreshToken := completion["refresh_token"]
|
||||
require.False(t, hasRefreshToken)
|
||||
}
|
||||
|
||||
func TestWeChatPaymentOAuthCallbackRedirectsWithOpaqueResumeToken(t *testing.T) {
|
||||
originalAccessTokenURL := wechatOAuthAccessTokenURL
|
||||
t.Cleanup(func() {
|
||||
|
||||
Reference in New Issue
Block a user