From 97c9b992cbf8b658b6ef27c27fd0041893b74317 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 20:27:15 +0800 Subject: [PATCH] fix: require wechat unionid for oauth identity --- backend/internal/handler/auth_wechat_oauth.go | 6 +- .../handler/auth_wechat_oauth_test.go | 107 +++++++++++------- 2 files changed, 66 insertions(+), 47 deletions(-) diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go index 816f60fd..f0755f1f 100644 --- a/backend/internal/handler/auth_wechat_oauth.go +++ b/backend/internal/handler/auth_wechat_oauth.go @@ -193,11 +193,11 @@ func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) { unionid := strings.TrimSpace(firstNonEmpty(userInfo.UnionID, tokenResp.UnionID)) openid := strings.TrimSpace(firstNonEmpty(userInfo.OpenID, tokenResp.OpenID)) - providerSubject := firstNonEmpty(unionid, openid) - if providerSubject == "" { - redirectOAuthError(c, frontendCallback, "provider_error", "wechat_missing_subject", "") + if unionid == "" { + redirectOAuthError(c, frontendCallback, "provider_error", "wechat_missing_unionid", "") return } + providerSubject := unionid username := firstNonEmpty(userInfo.Nickname, wechatFallbackUsername(providerSubject)) email := wechatSyntheticEmail(providerSubject) diff --git a/backend/internal/handler/auth_wechat_oauth_test.go b/backend/internal/handler/auth_wechat_oauth_test.go index 0d1df1b6..1ff80e1b 100644 --- a/backend/internal/handler/auth_wechat_oauth_test.go +++ b/backend/internal/handler/auth_wechat_oauth_test.go @@ -6,7 +6,6 @@ import ( "bytes" "context" "database/sql" - "encoding/base64" "net/http" "net/http/httptest" "net/url" @@ -122,6 +121,59 @@ func TestWeChatOAuthCallbackCreatesPendingSessionForUnifiedFlow(t *testing.T) { require.Equal(t, "openid-123", session.UpstreamIdentityClaims["openid"]) } +func TestWeChatOAuthCallbackRejectsMissingUnionID(t *testing.T) { + t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") + t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") + t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "https://app.example.com/auth/wechat/callback") + + originalAccessTokenURL := wechatOAuthAccessTokenURL + originalUserInfoURL := wechatOAuthUserInfoURL + t.Cleanup(func() { + wechatOAuthAccessTokenURL = originalAccessTokenURL + wechatOAuthUserInfoURL = originalUserInfoURL + }) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","scope":"snsapi_login"}`)) + case strings.Contains(r.URL.Path, "/sns/userinfo"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"openid":"openid-123","nickname":"WeChat Nick","headimgurl":"https://cdn.example/avatar.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" + wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" + + handler, client := newWeChatOAuthTestHandler(t, false) + defer client.Close() + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) + req.Host = "api.example.com" + req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) + req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) + req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + c.Request = req + + handler.WeChatOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Contains(t, recorder.Header().Get("Location"), "#error=provider_error") + require.Contains(t, recorder.Header().Get("Location"), "error_message=wechat_missing_unionid") + require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)) + + count, err := client.PendingAuthSession.Query().Count(context.Background()) + require.NoError(t, err) + require.Zero(t, count) +} + func TestWeChatOAuthCallbackBindUsesUnionCanonicalIdentityAcrossChannels(t *testing.T) { testCases := []struct { name string @@ -542,12 +594,7 @@ func newWeChatOAuthTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandl userRepo := &oauthPendingFlowUserRepo{client: client} redeemRepo := repository.NewRedeemCodeRepository(client) - settingSvc := service.NewSettingService(&wechatOAuthSettingRepoStub{ - values: map[string]string{ - service.SettingKeyRegistrationEnabled: "true", - service.SettingKeyInvitationCodeEnabled: boolSettingValue(invitationEnabled), - }, - }, &config.Config{ + cfg := &config.Config{ JWT: config.JWTConfig{ Secret: "test-secret", ExpireHour: 1, @@ -558,25 +605,20 @@ func newWeChatOAuthTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandl UserBalance: 0, UserConcurrency: 1, }, - }) + } + settingSvc := service.NewSettingService(&wechatOAuthSettingRepoStub{ + values: map[string]string{ + service.SettingKeyRegistrationEnabled: "true", + service.SettingKeyInvitationCodeEnabled: boolSettingValue(invitationEnabled), + }, + }, cfg) authSvc := service.NewAuthService( client, userRepo, redeemRepo, &wechatOAuthRefreshTokenCacheStub{}, - &config.Config{ - JWT: config.JWTConfig{ - Secret: "test-secret", - ExpireHour: 1, - AccessTokenExpireMinutes: 60, - RefreshTokenExpireDays: 7, - }, - Default: config.DefaultConfig{ - UserBalance: 0, - UserConcurrency: 1, - }, - }, + cfg, settingSvc, nil, nil, @@ -588,33 +630,10 @@ func newWeChatOAuthTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandl return &AuthHandler{ authService: authSvc, settingSvc: settingSvc, + cfg: cfg, }, client } -func encodedCookie(name, value string) *http.Cookie { - return &http.Cookie{ - Name: name, - Value: encodeCookieValue(value), - Path: "/", - } -} - -func findCookie(cookies []*http.Cookie, name string) *http.Cookie { - for _, cookie := range cookies { - if cookie.Name == name { - return cookie - } - } - return nil -} - -func decodeCookieValueForTest(t *testing.T, value string) string { - t.Helper() - raw, err := base64.RawURLEncoding.DecodeString(value) - require.NoError(t, err) - return string(raw) -} - func assertOAuthRedirectError(t *testing.T, location string, errorCode string, errorMessage string) { t.Helper()