package repository import ( "context" "crypto/sha256" "database/sql" "encoding/hex" "errors" "fmt" "io/fs" "sort" "strings" "time" "github.com/Wei-Shaw/sub2api/migrations" ) // schemaMigrationsTableDDL 定义迁移记录表的 DDL。 // 该表用于跟踪已应用的迁移文件及其校验和。 // - filename: 迁移文件名,作为主键唯一标识每个迁移 // - checksum: 文件内容的 SHA256 哈希值,用于检测迁移文件是否被篡改 // - applied_at: 迁移应用时间戳 const schemaMigrationsTableDDL = ` CREATE TABLE IF NOT EXISTS schema_migrations ( filename TEXT PRIMARY KEY, checksum TEXT NOT NULL, applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ); ` const atlasSchemaRevisionsTableDDL = ` CREATE TABLE IF NOT EXISTS atlas_schema_revisions ( version TEXT PRIMARY KEY, description TEXT NOT NULL, type INTEGER NOT NULL, applied INTEGER NOT NULL DEFAULT 0, total INTEGER NOT NULL DEFAULT 0, executed_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), execution_time BIGINT NOT NULL DEFAULT 0, error TEXT NULL, error_stmt TEXT NULL, hash TEXT NOT NULL DEFAULT '', partial_hashes TEXT[] NULL, operator_version TEXT NULL ); ` // migrationsAdvisoryLockID 是用于序列化迁移操作的 PostgreSQL Advisory Lock ID。 // 在多实例部署场景下,该锁确保同一时间只有一个实例执行迁移。 // 任何稳定的 int64 值都可以,只要不与同一数据库中的其他锁冲突即可。 const migrationsAdvisoryLockID int64 = 694208311321144027 const migrationsLockRetryInterval = 500 * time.Millisecond // ApplyMigrations 将嵌入的 SQL 迁移文件应用到指定的数据库。 // // 该函数可以在每次应用启动时安全调用: // - 已应用的迁移会被自动跳过(通过校验 filename 判断) // - 如果迁移文件内容被修改(checksum 不匹配),会返回错误 // - 使用 PostgreSQL Advisory Lock 确保多实例并发安全 // // 参数: // - ctx: 上下文,用于超时控制和取消 // - db: 数据库连接 // // 返回: // - error: 迁移过程中的任何错误 func ApplyMigrations(ctx context.Context, db *sql.DB) error { if db == nil { return errors.New("nil sql db") } return applyMigrationsFS(ctx, db, migrations.FS) } // applyMigrationsFS 是迁移执行的核心实现。 // 它从指定的文件系统读取 SQL 迁移文件并按顺序应用。 // // 迁移执行流程: // 1. 获取 PostgreSQL Advisory Lock,防止多实例并发迁移 // 2. 确保 schema_migrations 表存在 // 3. 按文件名排序读取所有 .sql 文件 // 4. 对于每个迁移文件: // - 计算文件内容的 SHA256 校验和 // - 检查该迁移是否已应用(通过 filename 查询) // - 如果已应用,验证校验和是否匹配 // - 如果未应用,在事务中执行迁移并记录 // 5. 释放 Advisory Lock // // 参数: // - ctx: 上下文 // - db: 数据库连接 // - fsys: 包含迁移文件的文件系统(通常是 embed.FS) func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error { if db == nil { return errors.New("nil sql db") } // 获取分布式锁,确保多实例部署时只有一个实例执行迁移。 // 这是 PostgreSQL 特有的 Advisory Lock 机制。 if err := pgAdvisoryLock(ctx, db); err != nil { return err } defer func() { // 无论迁移是否成功,都要释放锁。 // 使用 context.Background() 确保即使原 ctx 已取消也能释放锁。 _ = pgAdvisoryUnlock(context.Background(), db) }() // 创建迁移记录表(如果不存在)。 // 该表记录所有已应用的迁移及其校验和。 if _, err := db.ExecContext(ctx, schemaMigrationsTableDDL); err != nil { return fmt.Errorf("create schema_migrations: %w", err) } // 自动对齐 Atlas 基线(如果检测到 legacy schema_migrations 且缺失 atlas_schema_revisions)。 if err := ensureAtlasBaselineAligned(ctx, db, fsys); err != nil { return err } // 获取所有 .sql 迁移文件并按文件名排序。 // 命名规范:使用零填充数字前缀(如 001_init.sql, 002_add_users.sql)。 files, err := fs.Glob(fsys, "*.sql") if err != nil { return fmt.Errorf("list migrations: %w", err) } sort.Strings(files) // 确保按文件名顺序执行迁移 for _, name := range files { // 读取迁移文件内容 contentBytes, err := fs.ReadFile(fsys, name) if err != nil { return fmt.Errorf("read migration %s: %w", name, err) } content := strings.TrimSpace(string(contentBytes)) if content == "" { continue // 跳过空文件 } // 计算文件内容的 SHA256 校验和,用于检测文件是否被修改。 // 这是一种防篡改机制:如果有人修改了已应用的迁移文件,系统会拒绝启动。 sum := sha256.Sum256([]byte(content)) checksum := hex.EncodeToString(sum[:]) // 检查该迁移是否已经应用 var existing string rowErr := db.QueryRowContext(ctx, "SELECT checksum FROM schema_migrations WHERE filename = $1", name).Scan(&existing) if rowErr == nil { // 迁移已应用,验证校验和是否匹配 if existing != checksum { // 校验和不匹配意味着迁移文件在应用后被修改,这是危险的。 // 正确的做法是创建新的迁移文件来进行变更。 return fmt.Errorf( "migration %s checksum mismatch (db=%s file=%s)\n"+ "This means the migration file was modified after being applied to the database.\n"+ "Solutions:\n"+ " 1. Revert to original: git log --oneline -- migrations/%s && git checkout -- migrations/%s\n"+ " 2. For new changes, create a new migration file instead of modifying existing ones\n"+ "Note: Modifying applied migrations breaks the immutability principle and can cause inconsistencies across environments", name, existing, checksum, name, name, ) } continue // 迁移已应用且校验和匹配,跳过 } if !errors.Is(rowErr, sql.ErrNoRows) { return fmt.Errorf("check migration %s: %w", name, rowErr) } // 迁移未应用,在事务中执行。 // 使用事务确保迁移的原子性:要么完全成功,要么完全回滚。 tx, err := db.BeginTx(ctx, nil) if err != nil { return fmt.Errorf("begin migration %s: %w", name, err) } // 执行迁移 SQL if _, err := tx.ExecContext(ctx, content); err != nil { _ = tx.Rollback() return fmt.Errorf("apply migration %s: %w", name, err) } // 记录迁移已完成,保存文件名和校验和 if _, err := tx.ExecContext(ctx, "INSERT INTO schema_migrations (filename, checksum) VALUES ($1, $2)", name, checksum); err != nil { _ = tx.Rollback() return fmt.Errorf("record migration %s: %w", name, err) } // 提交事务 if err := tx.Commit(); err != nil { _ = tx.Rollback() return fmt.Errorf("commit migration %s: %w", name, err) } } return nil } func ensureAtlasBaselineAligned(ctx context.Context, db *sql.DB, fsys fs.FS) error { hasLegacy, err := tableExists(ctx, db, "schema_migrations") if err != nil { return fmt.Errorf("check schema_migrations: %w", err) } if !hasLegacy { return nil } hasAtlas, err := tableExists(ctx, db, "atlas_schema_revisions") if err != nil { return fmt.Errorf("check atlas_schema_revisions: %w", err) } if !hasAtlas { if _, err := db.ExecContext(ctx, atlasSchemaRevisionsTableDDL); err != nil { return fmt.Errorf("create atlas_schema_revisions: %w", err) } } var count int if err := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM atlas_schema_revisions").Scan(&count); err != nil { return fmt.Errorf("count atlas_schema_revisions: %w", err) } if count > 0 { return nil } version, description, hash, err := latestMigrationBaseline(fsys) if err != nil { return fmt.Errorf("atlas baseline version: %w", err) } if _, err := db.ExecContext(ctx, ` INSERT INTO atlas_schema_revisions (version, description, type, applied, total, executed_at, execution_time, hash) VALUES ($1, $2, $3, 0, 0, NOW(), 0, $4) `, version, description, 1, hash); err != nil { return fmt.Errorf("insert atlas baseline: %w", err) } return nil } func tableExists(ctx context.Context, db *sql.DB, tableName string) (bool, error) { var exists bool err := db.QueryRowContext(ctx, ` SELECT EXISTS ( SELECT 1 FROM information_schema.tables WHERE table_schema = 'public' AND table_name = $1 ) `, tableName).Scan(&exists) return exists, err } func latestMigrationBaseline(fsys fs.FS) (string, string, string, error) { files, err := fs.Glob(fsys, "*.sql") if err != nil { return "", "", "", err } if len(files) == 0 { return "baseline", "baseline", "", nil } sort.Strings(files) name := files[len(files)-1] contentBytes, err := fs.ReadFile(fsys, name) if err != nil { return "", "", "", err } content := strings.TrimSpace(string(contentBytes)) sum := sha256.Sum256([]byte(content)) hash := hex.EncodeToString(sum[:]) version := strings.TrimSuffix(name, ".sql") return version, version, hash, nil } // pgAdvisoryLock 获取 PostgreSQL Advisory Lock。 // Advisory Lock 是一种轻量级的锁机制,不与任何特定的数据库对象关联。 // 它非常适合用于应用层面的分布式锁场景,如迁移序列化。 func pgAdvisoryLock(ctx context.Context, db *sql.DB) error { ticker := time.NewTicker(migrationsLockRetryInterval) defer ticker.Stop() for { var locked bool if err := db.QueryRowContext(ctx, "SELECT pg_try_advisory_lock($1)", migrationsAdvisoryLockID).Scan(&locked); err != nil { return fmt.Errorf("acquire migrations lock: %w", err) } if locked { return nil } select { case <-ctx.Done(): return fmt.Errorf("acquire migrations lock: %w", ctx.Err()) case <-ticker.C: } } } // pgAdvisoryUnlock 释放 PostgreSQL Advisory Lock。 // 必须在获取锁后确保释放,否则会阻塞其他实例的迁移操作。 func pgAdvisoryUnlock(ctx context.Context, db *sql.DB) error { _, err := db.ExecContext(ctx, "SELECT pg_advisory_unlock($1)", migrationsAdvisoryLockID) if err != nil { return fmt.Errorf("release migrations lock: %w", err) } return nil }