fix: finalize oauth identity bindings

This commit is contained in:
IanShaw027
2026-04-20 21:24:33 +08:00
parent bdcd3d87e5
commit 5adefb466b
4 changed files with 376 additions and 7 deletions

View File

@@ -15,6 +15,7 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
@@ -563,6 +564,19 @@ func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing
require.Equal(t, "WeChat Display", identity.Metadata["display_name"])
require.Equal(t, "https://cdn.example/wechat.png", identity.Metadata["avatar_url"])
channel, err := client.AuthIdentityChannel.Query().
Where(
authidentitychannel.ProviderTypeEQ("wechat"),
authidentitychannel.ProviderKeyEQ("wechat-main"),
authidentitychannel.ChannelEQ("open"),
authidentitychannel.ChannelAppIDEQ("wx-open-app"),
authidentitychannel.ChannelSubjectEQ("openid-123"),
).
Only(ctx)
require.NoError(t, err)
require.Equal(t, identity.ID, channel.IdentityID)
require.Equal(t, "union-456", channel.Metadata["unionid"])
decision, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(pendingSession.ID)).
Only(ctx)
@@ -579,6 +593,116 @@ func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing
require.NotNil(t, consumed.ConsumedAt)
}
func TestWeChatOAuthCallbackRepairsLegacyOpenIDOnlyIdentity(t *testing.T) {
t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback")
originalAccessTokenURL := wechatOAuthAccessTokenURL
originalUserInfoURL := wechatOAuthUserInfoURL
t.Cleanup(func() {
wechatOAuthAccessTokenURL = originalAccessTokenURL
wechatOAuthUserInfoURL = originalUserInfoURL
})
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
case strings.Contains(r.URL.Path, "/sns/userinfo"):
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Legacy WeChat","headimgurl":"https://cdn.example/legacy.png"}`))
default:
http.NotFound(w, r)
}
}))
defer upstream.Close()
wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
handler, client := newWeChatOAuthTestHandler(t, false)
defer client.Close()
ctx := context.Background()
legacyUser, err := client.User.Create().
SetEmail("legacy@example.com").
SetUsername("legacy-user").
SetPasswordHash("hash").
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
legacyIdentity, err := client.AuthIdentity.Create().
SetUserID(legacyUser.ID).
SetProviderType("wechat").
SetProviderKey(wechatOAuthProviderKey).
SetProviderSubject("openid-123").
SetMetadata(map[string]any{"openid": "openid-123"}).
Save(ctx)
require.NoError(t, err)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
req.Host = "api.example.com"
req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
c.Request = req
handler.WeChatOAuthCallback(c)
require.Equal(t, http.StatusFound, recorder.Code)
require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location"))
sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
require.NotNil(t, sessionCookie)
session, err := client.PendingAuthSession.Query().
Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
Only(ctx)
require.NoError(t, err)
require.NotNil(t, session.TargetUserID)
require.Equal(t, legacyUser.ID, *session.TargetUserID)
require.Equal(t, legacyUser.Email, session.ResolvedEmail)
repairedIdentity, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("wechat"),
authidentity.ProviderKeyEQ(wechatOAuthProviderKey),
authidentity.ProviderSubjectEQ("union-456"),
).
Only(ctx)
require.NoError(t, err)
require.Equal(t, legacyIdentity.ID, repairedIdentity.ID)
require.Equal(t, legacyUser.ID, repairedIdentity.UserID)
openIDIdentityCount, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("wechat"),
authidentity.ProviderKeyEQ(wechatOAuthProviderKey),
authidentity.ProviderSubjectEQ("openid-123"),
).
Count(ctx)
require.NoError(t, err)
require.Zero(t, openIDIdentityCount)
channel, err := client.AuthIdentityChannel.Query().
Where(
authidentitychannel.ProviderTypeEQ("wechat"),
authidentitychannel.ProviderKeyEQ(wechatOAuthProviderKey),
authidentitychannel.ChannelEQ("open"),
authidentitychannel.ChannelAppIDEQ("wx-open-app"),
authidentitychannel.ChannelSubjectEQ("openid-123"),
).
Only(ctx)
require.NoError(t, err)
require.Equal(t, repairedIdentity.ID, channel.IdentityID)
}
func newWeChatOAuthTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) {
t.Helper()