diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go index 21ed2bc6..fd35e4e5 100644 --- a/backend/internal/handler/auth_oauth_pending_flow.go +++ b/backend/internal/handler/auth_oauth_pending_flow.go @@ -606,39 +606,42 @@ func ensurePendingWeChatOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, providerType := strings.TrimSpace(session.ProviderType) providerKey := strings.TrimSpace(session.ProviderKey) providerSubject := strings.TrimSpace(session.ProviderSubject) + providerKeys := wechatCompatibleProviderKeys(providerKey) channel := strings.TrimSpace(pendingSessionStringValue(session.UpstreamIdentityClaims, "channel")) channelAppID := strings.TrimSpace(pendingSessionStringValue(session.UpstreamIdentityClaims, "channel_app_id")) channelSubject := strings.TrimSpace(pendingSessionStringValue(session.UpstreamIdentityClaims, "channel_subject")) metadata := cloneOAuthMetadata(session.UpstreamIdentityClaims) - identity, err := client.AuthIdentity.Query(). + identityRecords, err := client.AuthIdentity.Query(). Where( authidentity.ProviderTypeEQ(providerType), - authidentity.ProviderKeyEQ(providerKey), + authidentity.ProviderKeyIn(providerKeys...), authidentity.ProviderSubjectEQ(providerSubject), ). - Only(ctx) - if err != nil && !dbent.IsNotFound(err) { + All(ctx) + if err != nil { return nil, err } - if identity != nil && identity.UserID != userID { - return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") + identity, hasCanonicalKey, err := chooseWeChatIdentityForUser(identityRecords, userID, providerKey) + if err != nil { + return nil, err } var legacyOpenIDIdentity *dbent.AuthIdentity if channelSubject != "" && channelSubject != providerSubject { - legacyOpenIDIdentity, err = client.AuthIdentity.Query(). + legacyOpenIDRecords, err := client.AuthIdentity.Query(). Where( authidentity.ProviderTypeEQ(providerType), - authidentity.ProviderKeyEQ(providerKey), + authidentity.ProviderKeyIn(providerKeys...), authidentity.ProviderSubjectEQ(channelSubject), ). - Only(ctx) - if err != nil && !dbent.IsNotFound(err) { + All(ctx) + if err != nil { return nil, err } - if legacyOpenIDIdentity != nil && legacyOpenIDIdentity.UserID != userID { - return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") + legacyOpenIDIdentity, _, err = chooseWeChatIdentityForUser(legacyOpenIDRecords, userID, providerKey) + if err != nil { + return nil, err } } @@ -646,6 +649,9 @@ func ensurePendingWeChatOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, case identity != nil: update := client.AuthIdentity.UpdateOneID(identity.ID). SetMetadata(mergeOAuthMetadata(identity.Metadata, metadata)) + if !strings.EqualFold(strings.TrimSpace(identity.ProviderKey), providerKey) && !hasCanonicalKey { + update = update.SetProviderKey(providerKey) + } if issuer := oauthIdentityIssuer(session); issuer != nil { update = update.SetIssuer(strings.TrimSpace(*issuer)) } @@ -655,6 +661,7 @@ func ensurePendingWeChatOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, } case legacyOpenIDIdentity != nil: update := client.AuthIdentity.UpdateOneID(legacyOpenIDIdentity.ID). + SetProviderKey(providerKey). SetProviderSubject(providerSubject). SetMetadata(mergeOAuthMetadata(legacyOpenIDIdentity.Metadata, metadata)) if issuer := oauthIdentityIssuer(session); issuer != nil { @@ -684,21 +691,22 @@ func ensurePendingWeChatOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, return identity, nil } - channelRecord, err := client.AuthIdentityChannel.Query(). + channelRecords, err := client.AuthIdentityChannel.Query(). Where( authidentitychannel.ProviderTypeEQ(providerType), - authidentitychannel.ProviderKeyEQ(providerKey), + authidentitychannel.ProviderKeyIn(providerKeys...), authidentitychannel.ChannelEQ(channel), authidentitychannel.ChannelAppIDEQ(channelAppID), authidentitychannel.ChannelSubjectEQ(channelSubject), ). WithIdentity(). - Only(ctx) - if err != nil && !dbent.IsNotFound(err) { + All(ctx) + if err != nil { return nil, err } - if channelRecord != nil && channelRecord.Edges.Identity != nil && channelRecord.Edges.Identity.UserID != userID { - return nil, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user") + channelRecord, hasCanonicalChannelKey, err := chooseWeChatChannelForUser(channelRecords, userID, providerKey) + if err != nil { + return nil, err } channelMetadata := mergeOAuthMetadata(channelRecordMetadata(channelRecord), metadata) @@ -717,16 +725,75 @@ func ensurePendingWeChatOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, return identity, nil } - _, err = client.AuthIdentityChannel.UpdateOneID(channelRecord.ID). + updateChannel := client.AuthIdentityChannel.UpdateOneID(channelRecord.ID). SetIdentityID(identity.ID). - SetMetadata(channelMetadata). - Save(ctx) + SetMetadata(channelMetadata) + if !strings.EqualFold(strings.TrimSpace(channelRecord.ProviderKey), providerKey) && !hasCanonicalChannelKey { + updateChannel = updateChannel.SetProviderKey(providerKey) + } + _, err = updateChannel.Save(ctx) if err != nil { return nil, err } return identity, nil } +func chooseWeChatIdentityForUser(records []*dbent.AuthIdentity, userID int64, preferredProviderKey string) (*dbent.AuthIdentity, bool, error) { + var preferred *dbent.AuthIdentity + var fallback *dbent.AuthIdentity + hasCanonicalKey := false + for _, record := range records { + if record == nil { + continue + } + if record.UserID != userID { + return nil, false, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") + } + if strings.EqualFold(strings.TrimSpace(record.ProviderKey), preferredProviderKey) { + hasCanonicalKey = true + if preferred == nil { + preferred = record + } + continue + } + if fallback == nil { + fallback = record + } + } + if preferred != nil { + return preferred, hasCanonicalKey, nil + } + return fallback, hasCanonicalKey, nil +} + +func chooseWeChatChannelForUser(records []*dbent.AuthIdentityChannel, userID int64, preferredProviderKey string) (*dbent.AuthIdentityChannel, bool, error) { + var preferred *dbent.AuthIdentityChannel + var fallback *dbent.AuthIdentityChannel + hasCanonicalKey := false + for _, record := range records { + if record == nil { + continue + } + if record.Edges.Identity != nil && record.Edges.Identity.UserID != userID { + return nil, false, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user") + } + if strings.EqualFold(strings.TrimSpace(record.ProviderKey), preferredProviderKey) { + hasCanonicalKey = true + if preferred == nil { + preferred = record + } + continue + } + if fallback == nil { + fallback = record + } + } + if preferred != nil { + return preferred, hasCanonicalKey, nil + } + return fallback, hasCanonicalKey, nil +} + func channelRecordMetadata(channel *dbent.AuthIdentityChannel) map[string]any { if channel == nil { return map[string]any{} diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go index b6d47670..95993dfc 100644 --- a/backend/internal/handler/auth_wechat_oauth.go +++ b/backend/internal/handler/auth_wechat_oauth.go @@ -34,6 +34,7 @@ const ( wechatOAuthDefaultRedirectTo = "/dashboard" wechatOAuthDefaultFrontendCB = "/auth/wechat/callback" wechatOAuthProviderKey = "wechat-main" + wechatOAuthLegacyProviderKey = "wechat" wechatOAuthIntentLogin = "login" wechatOAuthIntentBind = "bind_current_user" @@ -483,18 +484,20 @@ func (h *AuthHandler) ensureWeChatBindOwnership( return infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") } - identity, err := client.AuthIdentity.Query(). + identities, err := client.AuthIdentity.Query(). Where( authidentity.ProviderTypeEQ("wechat"), - authidentity.ProviderKeyEQ(wechatOAuthProviderKey), + authidentity.ProviderKeyIn(wechatCompatibleProviderKeys(wechatOAuthProviderKey)...), authidentity.ProviderSubjectEQ(strings.TrimSpace(providerSubject)), ). - Only(ctx) - if err != nil && !dbent.IsNotFound(err) { + All(ctx) + if err != nil { return infraerrors.InternalServer("WECHAT_BIND_LOOKUP_FAILED", "failed to inspect wechat identity ownership").WithCause(err) } - if identity != nil && identity.UserID != userID { - return infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") + for _, identity := range identities { + if identity != nil && identity.UserID != userID { + return infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") + } } channelSubject = strings.TrimSpace(channelSubject) @@ -503,21 +506,23 @@ func (h *AuthHandler) ensureWeChatBindOwnership( return nil } - channel, err := client.AuthIdentityChannel.Query(). + channels, err := client.AuthIdentityChannel.Query(). Where( authidentitychannel.ProviderTypeEQ("wechat"), - authidentitychannel.ProviderKeyEQ(wechatOAuthProviderKey), + authidentitychannel.ProviderKeyIn(wechatCompatibleProviderKeys(wechatOAuthProviderKey)...), authidentitychannel.ChannelEQ(strings.TrimSpace(cfg.mode)), authidentitychannel.ChannelAppIDEQ(channelAppID), authidentitychannel.ChannelSubjectEQ(channelSubject), ). WithIdentity(). - Only(ctx) - if err != nil && !dbent.IsNotFound(err) { + All(ctx) + if err != nil { return infraerrors.InternalServer("WECHAT_BIND_CHANNEL_LOOKUP_FAILED", "failed to inspect wechat identity channel ownership").WithCause(err) } - if channel != nil && channel.Edges.Identity != nil && channel.Edges.Identity.UserID != userID { - return infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user") + for _, channel := range channels { + if channel != nil && channel.Edges.Identity != nil && channel.Edges.Identity.UserID != userID { + return infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user") + } } return nil } @@ -533,14 +538,34 @@ func (h *AuthHandler) findWeChatUserByLegacyOpenID( return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") } + providerType := strings.TrimSpace(identity.ProviderType) + providerSubject := strings.TrimSpace(identity.ProviderSubject) + providerKeys := wechatCompatibleProviderKeys(identity.ProviderKey) + if providerSubject != "" { + records, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(providerType), + authidentity.ProviderKeyIn(providerKeys...), + authidentity.ProviderSubjectEQ(providerSubject), + ). + WithUser(). + All(ctx) + if err != nil { + return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err) + } + if user, err := singleWeChatIdentityUser(records); err != nil || user != nil { + return user, err + } + } + openid = strings.TrimSpace(openid) channel := strings.TrimSpace(cfg.mode) channelAppID := strings.TrimSpace(cfg.appID) if openid != "" && channel != "" && channelAppID != "" { - record, err := client.AuthIdentityChannel.Query(). + records, err := client.AuthIdentityChannel.Query(). Where( - authidentitychannel.ProviderTypeEQ(strings.TrimSpace(identity.ProviderType)), - authidentitychannel.ProviderKeyEQ(strings.TrimSpace(identity.ProviderKey)), + authidentitychannel.ProviderTypeEQ(providerType), + authidentitychannel.ProviderKeyIn(providerKeys...), authidentitychannel.ChannelEQ(channel), authidentitychannel.ChannelAppIDEQ(channelAppID), authidentitychannel.ChannelSubjectEQ(openid), @@ -548,12 +573,12 @@ func (h *AuthHandler) findWeChatUserByLegacyOpenID( WithIdentity(func(q *dbent.AuthIdentityQuery) { q.WithUser() }). - Only(ctx) - if err != nil && !dbent.IsNotFound(err) { + All(ctx) + if err != nil { return nil, infraerrors.InternalServer("AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED", "failed to inspect auth identity channel ownership").WithCause(err) } - if record != nil && record.Edges.Identity != nil && record.Edges.Identity.Edges.User != nil { - return record.Edges.Identity.Edges.User, nil + if user, err := singleWeChatChannelUser(records); err != nil || user != nil { + return user, err } } @@ -561,21 +586,64 @@ func (h *AuthHandler) findWeChatUserByLegacyOpenID( return nil, nil } - record, err := client.AuthIdentity.Query(). + records, err := client.AuthIdentity.Query(). Where( - authidentity.ProviderTypeEQ(strings.TrimSpace(identity.ProviderType)), - authidentity.ProviderKeyEQ(strings.TrimSpace(identity.ProviderKey)), + authidentity.ProviderTypeEQ(providerType), + authidentity.ProviderKeyIn(providerKeys...), authidentity.ProviderSubjectEQ(openid), ). WithUser(). - Only(ctx) + All(ctx) if err != nil { - if dbent.IsNotFound(err) { - return nil, nil - } return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err) } - return record.Edges.User, nil + return singleWeChatIdentityUser(records) +} + +func wechatCompatibleProviderKeys(providerKey string) []string { + preferred := strings.TrimSpace(providerKey) + if preferred == "" { + preferred = wechatOAuthProviderKey + } + keys := []string{preferred} + if !strings.EqualFold(preferred, wechatOAuthLegacyProviderKey) { + keys = append(keys, wechatOAuthLegacyProviderKey) + } + return keys +} + +func singleWeChatIdentityUser(records []*dbent.AuthIdentity) (*dbent.User, error) { + var resolved *dbent.User + for _, record := range records { + if record == nil || record.Edges.User == nil { + continue + } + if resolved == nil { + resolved = record.Edges.User + continue + } + if resolved.ID != record.Edges.User.ID { + return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") + } + } + return resolved, nil +} + +func singleWeChatChannelUser(records []*dbent.AuthIdentityChannel) (*dbent.User, error) { + var resolved *dbent.User + for _, record := range records { + if record == nil || record.Edges.Identity == nil || record.Edges.Identity.Edges.User == nil { + continue + } + if resolved == nil { + resolved = record.Edges.Identity.Edges.User + continue + } + if resolved.ID != record.Edges.Identity.Edges.User.ID { + return nil, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user") + } + } + return resolved, nil } func (h *AuthHandler) ensureWeChatRuntimeIdentityBinding( diff --git a/backend/internal/handler/auth_wechat_oauth_test.go b/backend/internal/handler/auth_wechat_oauth_test.go index dd022fb9..def9d5d6 100644 --- a/backend/internal/handler/auth_wechat_oauth_test.go +++ b/backend/internal/handler/auth_wechat_oauth_test.go @@ -467,6 +467,88 @@ func TestWeChatOAuthCallbackBindRejectsChannelOwnershipConflict(t *testing.T) { require.Zero(t, count) } +func TestWeChatOAuthCallbackBindRejectsLegacyProviderKeyOwnershipConflict(t *testing.T) { + t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") + t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") + t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback") + + originalAccessTokenURL := wechatOAuthAccessTokenURL + originalUserInfoURL := wechatOAuthUserInfoURL + t.Cleanup(func() { + wechatOAuthAccessTokenURL = originalAccessTokenURL + wechatOAuthUserInfoURL = originalUserInfoURL + }) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`)) + case strings.Contains(r.URL.Path, "/sns/userinfo"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Conflict Nick","headimgurl":"https://cdn.example/conflict.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" + wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" + + handler, client := newWeChatOAuthTestHandler(t, false) + defer client.Close() + + ctx := context.Background() + owner, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + currentUser, err := client.User.Create(). + SetEmail("current@example.com"). + SetUsername("current"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + _, err = client.AuthIdentity.Create(). + SetUserID(owner.ID). + SetProviderType("wechat"). + SetProviderKey(wechatOAuthLegacyProviderKey). + SetProviderSubject("union-456"). + SetMetadata(map[string]any{"unionid": "union-456"}). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) + req.Host = "api.example.com" + req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) + req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) + req.AddCookie(encodedCookie(wechatOAuthIntentCookieName, wechatOAuthIntentBind)) + req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) + req.AddCookie(encodedCookie(wechatOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret"))) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + c.Request = req + + handler.WeChatOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)) + assertOAuthRedirectError(t, recorder.Header().Get("Location"), "ownership_conflict", "AUTH_IDENTITY_OWNERSHIP_CONFLICT") + + count, err := client.PendingAuthSession.Query().Count(ctx) + require.NoError(t, err) + require.Zero(t, count) +} + func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing.T) { t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") @@ -703,6 +785,116 @@ func TestWeChatOAuthCallbackRepairsLegacyOpenIDOnlyIdentity(t *testing.T) { require.Equal(t, repairedIdentity.ID, channel.IdentityID) } +func TestWeChatOAuthCallbackRepairsLegacyProviderKeyCanonicalIdentity(t *testing.T) { + t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") + t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") + t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback") + + originalAccessTokenURL := wechatOAuthAccessTokenURL + originalUserInfoURL := wechatOAuthUserInfoURL + t.Cleanup(func() { + wechatOAuthAccessTokenURL = originalAccessTokenURL + wechatOAuthUserInfoURL = originalUserInfoURL + }) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`)) + case strings.Contains(r.URL.Path, "/sns/userinfo"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Legacy Canonical","headimgurl":"https://cdn.example/legacy-canonical.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" + wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" + + handler, client := newWeChatOAuthTestHandler(t, false) + defer client.Close() + + ctx := context.Background() + legacyUser, err := client.User.Create(). + SetEmail("legacy@example.com"). + SetUsername("legacy-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + legacyIdentity, err := client.AuthIdentity.Create(). + SetUserID(legacyUser.ID). + SetProviderType("wechat"). + SetProviderKey(wechatOAuthLegacyProviderKey). + SetProviderSubject("union-456"). + SetMetadata(map[string]any{"unionid": "union-456"}). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) + req.Host = "api.example.com" + req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) + req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) + req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + c.Request = req + + handler.WeChatOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location")) + + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + session, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, session.TargetUserID) + require.Equal(t, legacyUser.ID, *session.TargetUserID) + require.Equal(t, legacyUser.Email, session.ResolvedEmail) + + repairedIdentity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("wechat"), + authidentity.ProviderKeyEQ(wechatOAuthProviderKey), + authidentity.ProviderSubjectEQ("union-456"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, legacyIdentity.ID, repairedIdentity.ID) + require.Equal(t, legacyUser.ID, repairedIdentity.UserID) + + legacyIdentityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("wechat"), + authidentity.ProviderKeyEQ(wechatOAuthLegacyProviderKey), + authidentity.ProviderSubjectEQ("union-456"), + ). + Count(ctx) + require.NoError(t, err) + require.Zero(t, legacyIdentityCount) + + channel, err := client.AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ("wechat"), + authidentitychannel.ProviderKeyEQ(wechatOAuthProviderKey), + authidentitychannel.ChannelEQ("open"), + authidentitychannel.ChannelAppIDEQ("wx-open-app"), + authidentitychannel.ChannelSubjectEQ("openid-123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, repairedIdentity.ID, channel.IdentityID) +} + func newWeChatOAuthTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) { t.Helper() diff --git a/backend/migrations/113_normalize_legacy_wechat_provider_key.sql b/backend/migrations/113_normalize_legacy_wechat_provider_key.sql new file mode 100644 index 00000000..15610af0 --- /dev/null +++ b/backend/migrations/113_normalize_legacy_wechat_provider_key.sql @@ -0,0 +1,89 @@ +UPDATE auth_identities AS ai +SET + provider_key = 'wechat-main', + metadata = COALESCE(ai.metadata, '{}'::jsonb) || jsonb_build_object( + 'legacy_provider_key', 'wechat', + 'normalized_by_migration', '113_normalize_legacy_wechat_provider_key' + ), + updated_at = NOW() +WHERE ai.provider_type = 'wechat' + AND ai.provider_key = 'wechat' + AND NOT EXISTS ( + SELECT 1 + FROM auth_identities AS canon + WHERE canon.provider_type = 'wechat' + AND canon.provider_key = 'wechat-main' + AND canon.provider_subject = ai.provider_subject + ); + +UPDATE auth_identity_channels AS channel +SET + provider_key = 'wechat-main', + metadata = COALESCE(channel.metadata, '{}'::jsonb) || jsonb_build_object( + 'legacy_provider_key', 'wechat', + 'normalized_by_migration', '113_normalize_legacy_wechat_provider_key' + ), + updated_at = NOW() +WHERE channel.provider_type = 'wechat' + AND channel.provider_key = 'wechat' + AND NOT EXISTS ( + SELECT 1 + FROM auth_identity_channels AS canon + WHERE canon.provider_type = 'wechat' + AND canon.provider_key = 'wechat-main' + AND canon.channel = channel.channel + AND canon.channel_app_id = channel.channel_app_id + AND canon.channel_subject = channel.channel_subject + ); + +INSERT INTO auth_identity_migration_reports (report_type, report_key, details) +SELECT + 'wechat_provider_key_conflict', + CAST(ai.id AS TEXT), + jsonb_build_object( + 'legacy_identity_id', ai.id, + 'legacy_user_id', ai.user_id, + 'provider_subject', ai.provider_subject, + 'canonical_identity_id', canon.id, + 'canonical_user_id', canon.user_id, + 'same_user', canon.user_id = ai.user_id, + 'migration', '113_normalize_legacy_wechat_provider_key' + ) +FROM auth_identities AS ai +JOIN auth_identities AS canon + ON canon.provider_type = 'wechat' + AND canon.provider_key = 'wechat-main' + AND canon.provider_subject = ai.provider_subject +WHERE ai.provider_type = 'wechat' + AND ai.provider_key = 'wechat' +ON CONFLICT (report_type, report_key) DO NOTHING; + +INSERT INTO auth_identity_migration_reports (report_type, report_key, details) +SELECT + 'wechat_channel_provider_key_conflict', + CAST(channel.id AS TEXT), + jsonb_build_object( + 'legacy_channel_id', channel.id, + 'legacy_identity_id', channel.identity_id, + 'canonical_channel_id', canon.id, + 'canonical_identity_id', canon.identity_id, + 'channel', channel.channel, + 'channel_app_id', channel.channel_app_id, + 'channel_subject', channel.channel_subject, + 'same_user', COALESCE(legacy_identity.user_id = canonical_identity.user_id, FALSE), + 'migration', '113_normalize_legacy_wechat_provider_key' + ) +FROM auth_identity_channels AS channel +JOIN auth_identity_channels AS canon + ON canon.provider_type = 'wechat' + AND canon.provider_key = 'wechat-main' + AND canon.channel = channel.channel + AND canon.channel_app_id = channel.channel_app_id + AND canon.channel_subject = channel.channel_subject +LEFT JOIN auth_identities AS legacy_identity + ON legacy_identity.id = channel.identity_id +LEFT JOIN auth_identities AS canonical_identity + ON canonical_identity.id = canon.identity_id +WHERE channel.provider_type = 'wechat' + AND channel.provider_key = 'wechat' +ON CONFLICT (report_type, report_key) DO NOTHING;