fix(auth): preserve backward-compatible oauth defaults

This commit is contained in:
IanShaw027
2026-04-22 11:17:32 +08:00
parent dd314c41e3
commit 84628108fc
18 changed files with 661 additions and 142 deletions

View File

@@ -653,20 +653,22 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
req.WeChatConnectScopes = service.DefaultWeChatConnectScopesForMode(req.WeChatConnectMode)
}
}
if req.WeChatConnectRedirectURL == "" {
response.BadRequest(c, "WeChat Redirect URL is required when enabled")
return
}
if err := config.ValidateAbsoluteHTTPURL(req.WeChatConnectRedirectURL); err != nil {
response.BadRequest(c, "WeChat Redirect URL must be an absolute http(s) URL")
return
}
if req.WeChatConnectFrontendRedirectURL == "" {
req.WeChatConnectFrontendRedirectURL = "/auth/wechat/callback"
}
if err := config.ValidateFrontendRedirectURL(req.WeChatConnectFrontendRedirectURL); err != nil {
response.BadRequest(c, "WeChat Frontend Redirect URL is invalid")
return
if req.WeChatConnectOpenEnabled || req.WeChatConnectMPEnabled {
if req.WeChatConnectRedirectURL == "" {
response.BadRequest(c, "WeChat Redirect URL is required when web oauth is enabled")
return
}
if err := config.ValidateAbsoluteHTTPURL(req.WeChatConnectRedirectURL); err != nil {
response.BadRequest(c, "WeChat Redirect URL must be an absolute http(s) URL")
return
}
if req.WeChatConnectFrontendRedirectURL == "" {
req.WeChatConnectFrontendRedirectURL = "/auth/wechat/callback"
}
if err := config.ValidateFrontendRedirectURL(req.WeChatConnectFrontendRedirectURL); err != nil {
response.BadRequest(c, "WeChat Frontend Redirect URL is invalid")
return
}
}
}
@@ -749,14 +751,6 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.BadRequest(c, "OIDC scopes must contain openid")
return
}
if !req.OIDCConnectUsePKCE {
response.BadRequest(c, "OIDC PKCE must be enabled")
return
}
if !req.OIDCConnectValidateIDToken {
response.BadRequest(c, "OIDC ID Token validation must be enabled")
return
}
switch req.OIDCConnectTokenAuthMethod {
case "", "client_secret_post", "client_secret_basic", "none":
default:
@@ -767,7 +761,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.BadRequest(c, "OIDC clock skew seconds must be between 0 and 600")
return
}
if req.OIDCConnectAllowedSigningAlgs == "" {
if req.OIDCConnectValidateIDToken && req.OIDCConnectAllowedSigningAlgs == "" {
response.BadRequest(c, "OIDC Allowed Signing Algs is required when validate_id_token=true")
return
}

View File

@@ -123,13 +123,16 @@ func (h *AuthHandler) LinuxDoOAuthStart(c *gin.Context) {
clearCookie(c, linuxDoOAuthBindUserCookieName, secureCookie)
}
verifier, err := oauth.GenerateCodeVerifier()
if err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(err))
return
codeChallenge := ""
if cfg.UsePKCE {
verifier, err := oauth.GenerateCodeVerifier()
if err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(err))
return
}
codeChallenge = oauth.GenerateCodeChallenge(verifier)
setCookie(c, linuxDoOAuthVerifierCookie, encodeCookieValue(verifier), linuxDoOAuthCookieMaxAgeSec, secureCookie)
}
codeChallenge := oauth.GenerateCodeChallenge(verifier)
setCookie(c, linuxDoOAuthVerifierCookie, encodeCookieValue(verifier), linuxDoOAuthCookieMaxAgeSec, secureCookie)
redirectURI := strings.TrimSpace(cfg.RedirectURL)
if redirectURI == "" {
@@ -200,10 +203,13 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
intent, _ := readCookieDecoded(c, linuxDoOAuthIntentCookieName)
intent = normalizeOAuthIntent(intent)
codeVerifier, _ := readCookieDecoded(c, linuxDoOAuthVerifierCookie)
if codeVerifier == "" {
redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "")
return
codeVerifier := ""
if cfg.UsePKCE {
codeVerifier, _ = readCookieDecoded(c, linuxDoOAuthVerifierCookie)
if codeVerifier == "" {
redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "")
return
}
}
redirectURI := strings.TrimSpace(cfg.RedirectURL)
@@ -292,25 +298,16 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
return
}
if existingIdentityUser != nil {
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), existingIdentityUser.Email, username, "")
if err != nil {
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
return
}
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
Intent: oauthIntentLogin,
Identity: identityKey,
TargetUserID: &user.ID,
TargetUserID: &existingIdentityUser.ID,
ResolvedEmail: existingIdentityUser.Email,
RedirectTo: redirectTo,
BrowserSessionKey: browserSessionKey,
UpstreamIdentityClaims: upstreamClaims,
CompletionResponse: map[string]any{
"access_token": tokenPair.AccessToken,
"refresh_token": tokenPair.RefreshToken,
"expires_in": tokenPair.ExpiresIn,
"token_type": "Bearer",
"redirect": redirectTo,
"redirect": redirectTo,
},
}); err != nil {
redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
@@ -546,7 +543,9 @@ func linuxDoExchangeCode(
form.Set("client_id", cfg.ClientID)
form.Set("code", code)
form.Set("redirect_uri", redirectURI)
form.Set("code_verifier", codeVerifier)
if strings.TrimSpace(codeVerifier) != "" {
form.Set("code_verifier", codeVerifier)
}
r := client.R().
SetContext(ctx).
@@ -699,8 +698,10 @@ func buildLinuxDoAuthorizeURL(cfg config.LinuxDoConnectConfig, state string, cod
q.Set("scope", cfg.Scopes)
}
q.Set("state", state)
q.Set("code_challenge", codeChallenge)
q.Set("code_challenge_method", "S256")
if strings.TrimSpace(codeChallenge) != "" {
q.Set("code_challenge", codeChallenge)
q.Set("code_challenge_method", "S256")
}
u.RawQuery = q.Encode()
return u.String(), nil

View File

@@ -171,6 +171,80 @@ func TestLinuxDoOAuthBindStartRedirectsAndSetsBindCookies(t *testing.T) {
require.Equal(t, int64(42), userID)
}
func TestLinuxDoOAuthStartOmitsPKCEWhenDisabled(t *testing.T) {
handler := newLinuxDoOAuthTestHandler(t, false, config.LinuxDoConnectConfig{
Enabled: true,
ClientID: "linuxdo-client",
ClientSecret: "linuxdo-secret",
AuthorizeURL: "https://connect.linux.do/oauth/authorize",
TokenURL: "https://connect.linux.do/oauth/token",
UserInfoURL: "https://connect.linux.do/api/user",
Scopes: "read",
RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback",
FrontendRedirectURL: "/auth/linuxdo/callback",
TokenAuthMethod: "client_secret_post",
UsePKCE: false,
})
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/start?redirect=/dashboard", nil)
handler.LinuxDoOAuthStart(c)
require.Equal(t, http.StatusFound, recorder.Code)
require.NotContains(t, recorder.Header().Get("Location"), "code_challenge=")
require.Nil(t, findCookie(recorder.Result().Cookies(), linuxDoOAuthVerifierCookie))
}
func TestLinuxDoOAuthCallbackAllowsMissingVerifierWhenPKCEDisabled(t *testing.T) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/token":
require.NoError(t, r.ParseForm())
require.Empty(t, r.PostForm.Get("code_verifier"))
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`))
case "/userinfo":
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"id":"compat-subject","username":"linuxdo_user","name":"LinuxDo Display"}`))
default:
http.NotFound(w, r)
}
}))
defer upstream.Close()
handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{
Enabled: true,
ClientID: "linuxdo-client",
ClientSecret: "linuxdo-secret",
AuthorizeURL: upstream.URL + "/authorize",
TokenURL: upstream.URL + "/token",
UserInfoURL: upstream.URL + "/userinfo",
Scopes: "read",
RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback",
FrontendRedirectURL: "/auth/linuxdo/callback",
TokenAuthMethod: "client_secret_post",
UsePKCE: false,
})
t.Cleanup(func() { _ = client.Close() })
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=linuxdo-code&state=state-123", nil)
req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-123"))
req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/dashboard"))
req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentLogin))
req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
c.Request = req
handler.LinuxDoOAuthCallback(c)
require.Equal(t, http.StatusFound, recorder.Code)
require.Equal(t, "/auth/linuxdo/callback", recorder.Header().Get("Location"))
require.NotNil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
}
func TestLinuxDoOAuthBindStartAcceptsAccessTokenCookie(t *testing.T) {
handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{
Enabled: true,
@@ -327,7 +401,10 @@ func TestLinuxDoOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUser(t
completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
require.True(t, ok)
require.Equal(t, "/dashboard", completion["redirect"])
require.NotEmpty(t, completion["access_token"])
_, hasAccessToken := completion["access_token"]
require.False(t, hasAccessToken)
_, hasRefreshToken := completion["refresh_token"]
require.False(t, hasRefreshToken)
require.Nil(t, completion["error"])
}

View File

@@ -157,21 +157,25 @@ func (h *AuthHandler) OIDCOAuthStart(c *gin.Context) {
}
codeChallenge := ""
verifier, genErr := oauth.GenerateCodeVerifier()
if genErr != nil {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(genErr))
return
if cfg.UsePKCE {
verifier, genErr := oauth.GenerateCodeVerifier()
if genErr != nil {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(genErr))
return
}
codeChallenge = oauth.GenerateCodeChallenge(verifier)
oidcSetCookie(c, oidcOAuthVerifierCookie, encodeCookieValue(verifier), oidcOAuthCookieMaxAgeSec, secureCookie)
}
codeChallenge = oauth.GenerateCodeChallenge(verifier)
oidcSetCookie(c, oidcOAuthVerifierCookie, encodeCookieValue(verifier), oidcOAuthCookieMaxAgeSec, secureCookie)
nonce := ""
nonce, err = oauth.GenerateState()
if err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_NONCE_GEN_FAILED", "failed to generate oauth nonce").WithCause(err))
return
if cfg.ValidateIDToken {
nonce, err = oauth.GenerateState()
if err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_NONCE_GEN_FAILED", "failed to generate oauth nonce").WithCause(err))
return
}
oidcSetCookie(c, oidcOAuthNonceCookie, encodeCookieValue(nonce), oidcOAuthCookieMaxAgeSec, secureCookie)
}
oidcSetCookie(c, oidcOAuthNonceCookie, encodeCookieValue(nonce), oidcOAuthCookieMaxAgeSec, secureCookie)
redirectURI := strings.TrimSpace(cfg.RedirectURL)
if redirectURI == "" {
@@ -244,17 +248,21 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
intent = normalizeOAuthIntent(intent)
codeVerifier := ""
codeVerifier, _ = readCookieDecoded(c, oidcOAuthVerifierCookie)
if codeVerifier == "" {
redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "")
return
if cfg.UsePKCE {
codeVerifier, _ = readCookieDecoded(c, oidcOAuthVerifierCookie)
if codeVerifier == "" {
redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "")
return
}
}
expectedNonce := ""
expectedNonce, _ = readCookieDecoded(c, oidcOAuthNonceCookie)
if expectedNonce == "" {
redirectOAuthError(c, frontendCallback, "missing_nonce", "missing oauth nonce", "")
return
if cfg.ValidateIDToken {
expectedNonce, _ = readCookieDecoded(c, oidcOAuthNonceCookie)
if expectedNonce == "" {
redirectOAuthError(c, frontendCallback, "missing_nonce", "missing oauth nonce", "")
return
}
}
redirectURI := strings.TrimSpace(cfg.RedirectURL)
@@ -284,16 +292,19 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
return
}
if strings.TrimSpace(tokenResp.IDToken) == "" {
redirectOAuthError(c, frontendCallback, "missing_id_token", "missing id_token", "")
return
}
var idClaims *oidcIDTokenClaims
if cfg.ValidateIDToken {
if strings.TrimSpace(tokenResp.IDToken) == "" {
redirectOAuthError(c, frontendCallback, "missing_id_token", "missing id_token", "")
return
}
idClaims, err := oidcParseAndValidateIDToken(c.Request.Context(), cfg, tokenResp.IDToken, expectedNonce)
if err != nil {
log.Printf("[OIDC OAuth] id_token validation failed: %v", err)
redirectOAuthError(c, frontendCallback, "invalid_id_token", "failed to validate id_token", "")
return
idClaims, err = oidcParseAndValidateIDToken(c.Request.Context(), cfg, tokenResp.IDToken, expectedNonce)
if err != nil {
log.Printf("[OIDC OAuth] id_token validation failed: %v", err)
redirectOAuthError(c, frontendCallback, "invalid_id_token", "failed to validate id_token", "")
return
}
}
userInfoClaims, err := oidcFetchUserInfo(c.Request.Context(), cfg, tokenResp)
@@ -303,7 +314,10 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
return
}
subject := strings.TrimSpace(idClaims.Subject)
subject := ""
if idClaims != nil {
subject = strings.TrimSpace(idClaims.Subject)
}
if subject == "" {
subject = strings.TrimSpace(userInfoClaims.Subject)
}
@@ -311,7 +325,10 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
redirectOAuthError(c, frontendCallback, "missing_subject", "missing subject claim", "")
return
}
issuer := strings.TrimSpace(idClaims.Issuer)
issuer := ""
if idClaims != nil {
issuer = strings.TrimSpace(idClaims.Issuer)
}
if issuer == "" {
issuer = strings.TrimSpace(cfg.IssuerURL)
}
@@ -321,21 +338,34 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
}
emailVerified := userInfoClaims.EmailVerified
if emailVerified == nil {
if emailVerified == nil && idClaims != nil {
emailVerified = idClaims.EmailVerified
}
if userInfoClaims.Subject != "" && idClaims.Subject != "" && strings.TrimSpace(userInfoClaims.Subject) != strings.TrimSpace(idClaims.Subject) {
if idClaims != nil && userInfoClaims.Subject != "" && idClaims.Subject != "" && strings.TrimSpace(userInfoClaims.Subject) != strings.TrimSpace(idClaims.Subject) {
redirectOAuthError(c, frontendCallback, "subject_mismatch", "userinfo subject does not match id_token", "")
return
}
identityKey := oidcIdentityKey(issuer, subject)
compatEmail := strings.TrimSpace(firstNonEmpty(userInfoClaims.Email, idClaims.Email))
compatEmail := strings.TrimSpace(userInfoClaims.Email)
if compatEmail == "" && idClaims != nil {
compatEmail = strings.TrimSpace(idClaims.Email)
}
email := oidcSyntheticEmailFromIdentityKey(identityKey)
username := firstNonEmpty(
userInfoClaims.Username,
idClaims.PreferredUsername,
idClaims.Name,
func() string {
if idClaims != nil {
return idClaims.PreferredUsername
}
return ""
}(),
func() string {
if idClaims != nil {
return idClaims.Name
}
return ""
}(),
oidcFallbackUsername(subject),
)
identityRef := service.PendingAuthIdentityKey{
@@ -350,7 +380,12 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
"issuer": issuer,
"email_verified": emailVerified != nil && *emailVerified,
"provider_fallback": strings.TrimSpace(cfg.ProviderName),
"suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username),
"suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, func() string {
if idClaims != nil {
return idClaims.Name
}
return ""
}(), username),
"suggested_avatar_url": userInfoClaims.AvatarURL,
}
if compatEmail != "" && !strings.EqualFold(strings.TrimSpace(compatEmail), strings.TrimSpace(email)) {
@@ -387,25 +422,16 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
return
}
if existingIdentityUser != nil {
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), existingIdentityUser.Email, username, "")
if err != nil {
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
return
}
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
Intent: oauthIntentLogin,
Identity: identityRef,
TargetUserID: &user.ID,
TargetUserID: &existingIdentityUser.ID,
ResolvedEmail: existingIdentityUser.Email,
RedirectTo: redirectTo,
BrowserSessionKey: browserSessionKey,
UpstreamIdentityClaims: upstreamClaims,
CompletionResponse: map[string]any{
"access_token": tokenPair.AccessToken,
"refresh_token": tokenPair.RefreshToken,
"expires_in": tokenPair.ExpiresIn,
"token_type": "Bearer",
"redirect": redirectTo,
"redirect": redirectTo,
},
}); err != nil {
redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
@@ -670,7 +696,9 @@ func oidcExchangeCode(
form.Set("client_id", cfg.ClientID)
form.Set("code", code)
form.Set("redirect_uri", redirectURI)
form.Set("code_verifier", codeVerifier)
if strings.TrimSpace(codeVerifier) != "" {
form.Set("code_verifier", codeVerifier)
}
r := client.R().
SetContext(ctx).
@@ -872,9 +900,13 @@ func buildOIDCAuthorizeURL(cfg config.OIDCConnectConfig, state, nonce, codeChall
q.Set("scope", cfg.Scopes)
}
q.Set("state", state)
q.Set("nonce", nonce)
q.Set("code_challenge", codeChallenge)
q.Set("code_challenge_method", "S256")
if strings.TrimSpace(nonce) != "" {
q.Set("nonce", nonce)
}
if strings.TrimSpace(codeChallenge) != "" {
q.Set("code_challenge", codeChallenge)
q.Set("code_challenge_method", "S256")
}
u.RawQuery = q.Encode()
return u.String(), nil

View File

@@ -186,6 +186,89 @@ func TestOIDCOAuthBindStartRedirectsAndSetsBindCookies(t *testing.T) {
require.Equal(t, int64(84), userID)
}
func TestOIDCOAuthStartOmitsPKCEAndNonceWhenDisabled(t *testing.T) {
handler := newOIDCOAuthTestHandler(t, false, config.OIDCConnectConfig{
Enabled: true,
ClientID: "oidc-client",
ClientSecret: "oidc-secret",
IssuerURL: "https://issuer.example.com",
AuthorizeURL: "https://issuer.example.com/oauth/authorize",
TokenURL: "https://issuer.example.com/oauth/token",
UserInfoURL: "https://issuer.example.com/oauth/userinfo",
Scopes: "openid profile email",
RedirectURL: "https://api.example.com/api/v1/auth/oauth/oidc/callback",
FrontendRedirectURL: "/auth/oidc/callback",
TokenAuthMethod: "client_secret_post",
UsePKCE: false,
ValidateIDToken: false,
RequireEmailVerified: false,
})
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/start?redirect=/dashboard", nil)
handler.OIDCOAuthStart(c)
require.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
require.NotContains(t, location, "code_challenge=")
require.NotContains(t, location, "nonce=")
require.Nil(t, findCookie(recorder.Result().Cookies(), oidcOAuthVerifierCookie))
require.Nil(t, findCookie(recorder.Result().Cookies(), oidcOAuthNonceCookie))
}
func TestOIDCOAuthCallbackAllowsOptionalPKCEAndIDTokenValidation(t *testing.T) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/token":
require.NoError(t, r.ParseForm())
require.Empty(t, r.PostForm.Get("code_verifier"))
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"access_token":"oidc-access","token_type":"Bearer","expires_in":3600}`))
case "/userinfo":
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"sub":"oidc-subject-compat","preferred_username":"oidc_user","name":"OIDC Display","email":"oidc@example.com"}`))
default:
http.NotFound(w, r)
}
}))
defer upstream.Close()
handler, client := newOIDCOAuthHandlerAndClient(t, false, config.OIDCConnectConfig{
Enabled: true,
ClientID: "oidc-client",
ClientSecret: "oidc-secret",
IssuerURL: "https://issuer.example.com",
AuthorizeURL: upstream.URL + "/authorize",
TokenURL: upstream.URL + "/token",
UserInfoURL: upstream.URL + "/userinfo",
Scopes: "openid profile email",
RedirectURL: "https://api.example.com/api/v1/auth/oauth/oidc/callback",
FrontendRedirectURL: "/auth/oidc/callback",
TokenAuthMethod: "client_secret_post",
UsePKCE: false,
ValidateIDToken: false,
RequireEmailVerified: false,
})
t.Cleanup(func() { _ = client.Close() })
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-123", nil)
req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-123"))
req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/dashboard"))
req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin))
req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
c.Request = req
handler.OIDCOAuthCallback(c)
require.Equal(t, http.StatusFound, recorder.Code)
require.Equal(t, "/auth/oidc/callback", recorder.Header().Get("Location"))
require.NotNil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
}
func TestOIDCOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUser(t *testing.T) {
cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{
Subject: "oidc-subject-login",
@@ -250,7 +333,10 @@ func TestOIDCOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUser(t *t
completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
require.True(t, ok)
require.Equal(t, "/dashboard", completion["redirect"])
require.NotEmpty(t, completion["access_token"])
_, hasAccessToken := completion["access_token"]
require.False(t, hasAccessToken)
_, hasRefreshToken := completion["refresh_token"]
require.False(t, hasRefreshToken)
require.Nil(t, completion["error"])
}

View File

@@ -279,12 +279,7 @@ func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) {
redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
return
}
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), existingIdentityUser.Email, username, "")
if err != nil {
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
return
}
if err := h.createWeChatPendingSession(c, normalizedIntent, providerSubject, existingIdentityUser.Email, redirectTo, browserSessionKey, upstreamClaims, tokenPair, nil, &user.ID); err != nil {
if err := h.createWeChatPendingSession(c, normalizedIntent, providerSubject, existingIdentityUser.Email, redirectTo, browserSessionKey, upstreamClaims, nil, nil, &existingIdentityUser.ID); err != nil {
redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
return
}

View File

@@ -213,6 +213,86 @@ func TestWeChatOAuthCallbackFallsBackToOpenIDWhenUnionIDMissingInSingleChannelMo
require.Equal(t, "third_party_signup", completion["choice_reason"])
}
func TestWeChatOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUserWithoutStoredTokens(t *testing.T) {
originalAccessTokenURL := wechatOAuthAccessTokenURL
originalUserInfoURL := wechatOAuthUserInfoURL
t.Cleanup(func() {
wechatOAuthAccessTokenURL = originalAccessTokenURL
wechatOAuthUserInfoURL = originalUserInfoURL
})
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
case strings.Contains(r.URL.Path, "/sns/userinfo"):
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"WeChat Display","headimgurl":"https://cdn.example/wechat-login.png"}`))
default:
http.NotFound(w, r)
}
}))
defer upstream.Close()
wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings("open", "wx-open-app", "wx-open-secret", "https://app.example.com/auth/wechat/callback"))
defer client.Close()
ctx := context.Background()
existingUser, err := client.User.Create().
SetEmail(wechatSyntheticEmail("union-456")).
SetUsername("wechat-existing-user").
SetPasswordHash("hash").
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
_, err = client.AuthIdentity.Create().
SetUserID(existingUser.ID).
SetProviderType("wechat").
SetProviderKey(wechatOAuthProviderKey).
SetProviderSubject("union-456").
SetMetadata(map[string]any{"username": "wechat-existing-user"}).
Save(ctx)
require.NoError(t, err)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
req.Host = "api.example.com"
req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
c.Request = req
handler.WeChatOAuthCallback(c)
require.Equal(t, http.StatusFound, recorder.Code)
require.Equal(t, "https://app.example.com/auth/wechat/callback", recorder.Header().Get("Location"))
sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
require.NotNil(t, sessionCookie)
session, err := client.PendingAuthSession.Query().
Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
Only(ctx)
require.NoError(t, err)
require.Equal(t, oauthIntentLogin, session.Intent)
require.NotNil(t, session.TargetUserID)
require.Equal(t, existingUser.ID, *session.TargetUserID)
require.Equal(t, existingUser.Email, session.ResolvedEmail)
completion := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
require.Equal(t, "/dashboard", completion["redirect"])
_, hasAccessToken := completion["access_token"]
require.False(t, hasAccessToken)
_, hasRefreshToken := completion["refresh_token"]
require.False(t, hasRefreshToken)
}
func TestWeChatPaymentOAuthCallbackRedirectsWithOpaqueResumeToken(t *testing.T) {
originalAccessTokenURL := wechatOAuthAccessTokenURL
t.Cleanup(func() {