Add auth identity legacy backfill and email sync

This commit is contained in:
IanShaw027
2026-04-21 00:13:40 +08:00
parent 9204145746
commit 5d58c7c6fb
8 changed files with 753 additions and 0 deletions

View File

@@ -0,0 +1,206 @@
//go:build integration
package repository
import (
"context"
"os"
"path/filepath"
"strconv"
"testing"
"github.com/stretchr/testify/require"
)
func TestAuthIdentityLegacyExternalBackfillMigration(t *testing.T) {
tx := testTx(t)
ctx := context.Background()
migrationPath := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql")
migrationSQL, err := os.ReadFile(migrationPath)
require.NoError(t, err)
_, err = tx.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS user_external_identities (
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL,
provider TEXT NOT NULL,
provider_user_id TEXT NOT NULL,
provider_union_id TEXT NULL,
provider_username TEXT NOT NULL DEFAULT '',
display_name TEXT NOT NULL DEFAULT '',
profile_url TEXT NOT NULL DEFAULT '',
avatar_url TEXT NOT NULL DEFAULT '',
metadata TEXT NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
);
TRUNCATE TABLE
auth_identity_channels,
auth_identities,
auth_identity_migration_reports,
user_external_identities,
users
RESTART IDENTITY;
`)
require.NoError(t, err)
var linuxDoUserID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-linuxdo@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&linuxDoUserID))
var wechatUnionUserID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-wechat-union@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&wechatUnionUserID))
var wechatOpenIDOnlyUserID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-wechat-openid@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&wechatOpenIDOnlyUserID))
var syntheticAuthIdentityID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO auth_identities (user_id, provider_type, provider_key, provider_subject, metadata)
VALUES ($1, 'wechat', 'wechat-main', 'openid-synthetic', '{"backfill_source":"synthetic_email"}'::jsonb)
RETURNING id`, wechatOpenIDOnlyUserID).Scan(&syntheticAuthIdentityID))
var linuxDoLegacyID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'linuxdo', 'linuxdo-user-1', NULL, 'linux-user', 'Linux User', '{"source":"legacy"}')
RETURNING id
`, linuxDoUserID).Scan(&linuxDoLegacyID))
var wechatUnionLegacyID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'wechat', 'openid-union-1', 'union-1', 'wechat-union-user', 'WeChat Union User', '{"channel":"oa","appid":"wx-app-1"}')
RETURNING id
`, wechatUnionUserID).Scan(&wechatUnionLegacyID))
var wechatOpenIDLegacyID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'wechat', 'openid-only-1', NULL, 'wechat-openid-user', 'WeChat OpenID User', '{"channel":"oa","appid":"wx-app-2"}')
RETURNING id
`, wechatOpenIDOnlyUserID).Scan(&wechatOpenIDLegacyID))
_, err = tx.ExecContext(ctx, string(migrationSQL))
require.NoError(t, err)
var linuxDoCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identities
WHERE user_id = $1
AND provider_type = 'linuxdo'
AND provider_key = 'linuxdo'
AND provider_subject = 'linuxdo-user-1'
`, linuxDoUserID).Scan(&linuxDoCount))
require.Equal(t, 1, linuxDoCount)
var wechatSubject string
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT provider_subject
FROM auth_identities
WHERE user_id = $1
AND provider_type = 'wechat'
AND provider_key = 'wechat-main'
AND provider_subject = 'union-1'
`, wechatUnionUserID).Scan(&wechatSubject))
require.Equal(t, "union-1", wechatSubject)
var wechatChannelCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identity_channels channel
JOIN auth_identities ai ON ai.id = channel.identity_id
WHERE ai.user_id = $1
AND channel.provider_type = 'wechat'
AND channel.provider_key = 'wechat-main'
AND channel.channel = 'oa'
AND channel.channel_app_id = 'wx-app-1'
AND channel.channel_subject = 'openid-union-1'
`, wechatUnionUserID).Scan(&wechatChannelCount))
require.Equal(t, 1, wechatChannelCount)
var legacyOpenIDOnlyReportCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identity_migration_reports
WHERE report_type = 'wechat_openid_only_requires_remediation'
AND report_key = $1
`, "legacy_external_identity:"+strconv.FormatInt(wechatOpenIDLegacyID, 10)).Scan(&legacyOpenIDOnlyReportCount))
require.Equal(t, 1, legacyOpenIDOnlyReportCount)
var syntheticReviewCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identity_migration_reports
WHERE report_type = 'wechat_openid_only_requires_remediation'
AND report_key = $1
`, "synthetic_auth_identity:"+strconv.FormatInt(syntheticAuthIdentityID, 10)).Scan(&syntheticReviewCount))
require.Equal(t, 1, syntheticReviewCount)
var unionLegacyReportCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identity_migration_reports
WHERE report_type = 'wechat_openid_only_requires_remediation'
AND report_key = $1
`, "legacy_external_identity:"+strconv.FormatInt(wechatUnionLegacyID, 10)).Scan(&unionLegacyReportCount))
require.Zero(t, unionLegacyReportCount)
require.NotZero(t, linuxDoLegacyID)
}
func TestAuthIdentityLegacyExternalBackfillMigration_IsSafeWhenLegacyTableMissing(t *testing.T) {
tx := testTx(t)
ctx := context.Background()
migrationPath := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql")
migrationSQL, err := os.ReadFile(migrationPath)
require.NoError(t, err)
var beforeCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identity_migration_reports
`).Scan(&beforeCount))
_, err = tx.ExecContext(ctx, string(migrationSQL))
require.NoError(t, err)
var afterCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identity_migration_reports
`).Scan(&afterCount))
require.Equal(t, beforeCount, afterCount)
}

View File

@@ -11,6 +11,7 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
dbgroup "github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
@@ -76,6 +77,9 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, created.ID, userIn.AllowedGroups); err != nil {
return err
}
if err := ensureEmailAuthIdentityWithClient(ctx, txClient, created.ID, created.Email, "user_repo_create"); err != nil {
return err
}
if tx != nil {
if err := tx.Commit(); err != nil {
@@ -150,6 +154,11 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
// 已处于外部事务中ErrTxStarted复用当前 client 并由调用方负责提交/回滚。
txClient = r.client
}
existing, err := clientFromContext(ctx, txClient).User.Get(ctx, userIn.ID)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
oldEmail := existing.Email
updateOp := txClient.User.UpdateOneID(userIn.ID).
SetEmail(userIn.Email).
@@ -185,6 +194,9 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, updated.ID, userIn.AllowedGroups); err != nil {
return err
}
if err := replaceEmailAuthIdentityWithClient(ctx, txClient, updated.ID, oldEmail, updated.Email, "user_repo_update"); err != nil {
return err
}
if tx != nil {
if err := tx.Commit(); err != nil {
@@ -196,6 +208,96 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
return nil
}
func (r *userRepository) EnsureEmailAuthIdentity(ctx context.Context, userID int64, email string) error {
return ensureEmailAuthIdentityWithClient(ctx, r.client, userID, email, "service_dual_write")
}
func (r *userRepository) ReplaceEmailAuthIdentity(ctx context.Context, userID int64, oldEmail, newEmail string) error {
return replaceEmailAuthIdentityWithClient(ctx, r.client, userID, oldEmail, newEmail, "service_dual_write")
}
func ensureEmailAuthIdentityWithClient(ctx context.Context, client *dbent.Client, userID int64, email string, source string) error {
client = clientFromContext(ctx, client)
if client == nil || userID <= 0 {
return nil
}
subject := normalizeEmailAuthIdentitySubject(email)
if subject == "" {
return nil
}
if err := client.AuthIdentity.Create().
SetUserID(userID).
SetProviderType("email").
SetProviderKey("email").
SetProviderSubject(subject).
SetVerifiedAt(time.Now().UTC()).
SetMetadata(map[string]any{"source": source}).
OnConflictColumns(
authidentity.FieldProviderType,
authidentity.FieldProviderKey,
authidentity.FieldProviderSubject,
).
DoNothing().
Exec(ctx); err != nil {
return err
}
identity, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("email"),
authidentity.ProviderKeyEQ("email"),
authidentity.ProviderSubjectEQ(subject),
).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil
}
return err
}
if identity.UserID != userID {
return ErrAuthIdentityOwnershipConflict
}
return nil
}
func replaceEmailAuthIdentityWithClient(ctx context.Context, client *dbent.Client, userID int64, oldEmail, newEmail string, source string) error {
newSubject := normalizeEmailAuthIdentitySubject(newEmail)
if err := ensureEmailAuthIdentityWithClient(ctx, client, userID, newEmail, source); err != nil {
return err
}
oldSubject := normalizeEmailAuthIdentitySubject(oldEmail)
if oldSubject == "" || oldSubject == newSubject {
return nil
}
_, err := clientFromContext(ctx, client).AuthIdentity.Delete().
Where(
authidentity.UserIDEQ(userID),
authidentity.ProviderTypeEQ("email"),
authidentity.ProviderKeyEQ("email"),
authidentity.ProviderSubjectEQ(oldSubject),
).
Exec(ctx)
return err
}
func normalizeEmailAuthIdentitySubject(email string) string {
normalized := strings.ToLower(strings.TrimSpace(email))
if normalized == "" {
return ""
}
if strings.HasSuffix(normalized, service.LinuxDoConnectSyntheticEmailDomain) ||
strings.HasSuffix(normalized, service.OIDCConnectSyntheticEmailDomain) ||
strings.HasSuffix(normalized, service.WeChatConnectSyntheticEmailDomain) {
return ""
}
return normalized
}
func (r *userRepository) Delete(ctx context.Context, id int64) error {
affected, err := r.client.User.Delete().Where(dbuser.IDEQ(id)).Exec(ctx)
if err != nil {

View File

@@ -0,0 +1,86 @@
//go:build integration
package repository
import (
"context"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func (s *UserRepoSuite) TestCreate_CreatesEmailAuthIdentityForNormalEmail() {
user := &service.User{
Email: "repo-create@example.com",
PasswordHash: "test-password-hash",
Role: service.RoleUser,
Status: service.StatusActive,
Concurrency: 2,
}
s.Require().NoError(s.repo.Create(s.ctx, user))
identity, err := s.client.AuthIdentity.Query().
Where(
authidentity.UserIDEQ(user.ID),
authidentity.ProviderTypeEQ("email"),
authidentity.ProviderKeyEQ("email"),
authidentity.ProviderSubjectEQ("repo-create@example.com"),
).
Only(s.ctx)
s.Require().NoError(err)
s.Require().Equal(user.ID, identity.UserID)
}
func (s *UserRepoSuite) TestCreate_SkipsEmailAuthIdentityForSyntheticLinuxDoEmail() {
user := &service.User{
Email: "linuxdo-legacy-user@linuxdo-connect.invalid",
PasswordHash: "test-password-hash",
Role: service.RoleUser,
Status: service.StatusActive,
Concurrency: 2,
}
s.Require().NoError(s.repo.Create(s.ctx, user))
count, err := s.client.AuthIdentity.Query().
Where(
authidentity.UserIDEQ(user.ID),
authidentity.ProviderTypeEQ("email"),
authidentity.ProviderKeyEQ("email"),
).
Count(s.ctx)
s.Require().NoError(err)
s.Require().Zero(count)
}
func (s *UserRepoSuite) TestUpdate_ReplacesEmailAuthIdentityWhenEmailChanges() {
user := s.mustCreateUser(&service.User{
Email: "before-update@example.com",
})
user.Email = "after-update@example.com"
s.Require().NoError(s.repo.Update(s.ctx, user))
newIdentity, err := s.client.AuthIdentity.Query().
Where(
authidentity.UserIDEQ(user.ID),
authidentity.ProviderTypeEQ("email"),
authidentity.ProviderKeyEQ("email"),
authidentity.ProviderSubjectEQ("after-update@example.com"),
).
Only(s.ctx)
s.Require().NoError(err)
s.Require().Equal(user.ID, newIdentity.UserID)
oldCount, err := s.client.AuthIdentity.Query().
Where(
authidentity.UserIDEQ(user.ID),
authidentity.ProviderTypeEQ("email"),
authidentity.ProviderKeyEQ("email"),
authidentity.ProviderSubjectEQ("before-update@example.com"),
).
Count(context.Background())
s.Require().NoError(err)
s.Require().Zero(oldCount)
}

View File

@@ -26,6 +26,8 @@ func (s *UserRepoSuite) SetupTest() {
s.repo = newUserRepositoryWithSQL(s.client, integrationDB)
// 清理测试数据,确保每个测试从干净状态开始
_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM auth_identity_channels")
_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM auth_identities")
_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM user_subscriptions")
_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM user_allowed_groups")
_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM users")