369 lines
12 KiB
Go
369 lines
12 KiB
Go
package repository
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"errors"
|
|
"io/fs"
|
|
"strings"
|
|
"testing"
|
|
"testing/fstest"
|
|
"time"
|
|
|
|
sqlmock "github.com/DATA-DOG/go-sqlmock"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestApplyMigrations_NilDB(t *testing.T) {
|
|
err := ApplyMigrations(context.Background(), nil)
|
|
require.Error(t, err)
|
|
require.Contains(t, err.Error(), "nil sql db")
|
|
}
|
|
|
|
func TestApplyMigrations_DelegatesToApplyMigrationsFS(t *testing.T) {
|
|
db, mock, err := sqlmock.New()
|
|
require.NoError(t, err)
|
|
defer func() { _ = db.Close() }()
|
|
|
|
mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)").
|
|
WithArgs(migrationsAdvisoryLockID).
|
|
WillReturnError(errors.New("lock failed"))
|
|
|
|
err = ApplyMigrations(context.Background(), db)
|
|
require.Error(t, err)
|
|
require.Contains(t, err.Error(), "acquire migrations lock")
|
|
require.NoError(t, mock.ExpectationsWereMet())
|
|
}
|
|
|
|
func TestLatestMigrationBaseline(t *testing.T) {
|
|
t.Run("empty_fs_returns_baseline", func(t *testing.T) {
|
|
version, description, hash, err := latestMigrationBaseline(fstest.MapFS{})
|
|
require.NoError(t, err)
|
|
require.Equal(t, "baseline", version)
|
|
require.Equal(t, "baseline", description)
|
|
require.Equal(t, "", hash)
|
|
})
|
|
|
|
t.Run("uses_latest_sorted_sql_file", func(t *testing.T) {
|
|
fsys := fstest.MapFS{
|
|
"001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t1(id int);")},
|
|
"010_final.sql": &fstest.MapFile{
|
|
Data: []byte("CREATE TABLE t2(id int);"),
|
|
},
|
|
}
|
|
version, description, hash, err := latestMigrationBaseline(fsys)
|
|
require.NoError(t, err)
|
|
require.Equal(t, "010_final", version)
|
|
require.Equal(t, "010_final", description)
|
|
require.Len(t, hash, 64)
|
|
})
|
|
|
|
t.Run("read_file_error", func(t *testing.T) {
|
|
fsys := fstest.MapFS{
|
|
"010_bad.sql": &fstest.MapFile{Mode: fs.ModeDir},
|
|
}
|
|
_, _, _, err := latestMigrationBaseline(fsys)
|
|
require.Error(t, err)
|
|
})
|
|
}
|
|
|
|
func TestIsMigrationChecksumCompatible_AdditionalCases(t *testing.T) {
|
|
require.False(t, isMigrationChecksumCompatible("unknown.sql", "db", "file"))
|
|
|
|
var (
|
|
name string
|
|
rule migrationChecksumCompatibilityRule
|
|
)
|
|
for n, r := range migrationChecksumCompatibilityRules {
|
|
name = n
|
|
rule = r
|
|
break
|
|
}
|
|
require.NotEmpty(t, name)
|
|
|
|
require.False(t, isMigrationChecksumCompatible(name, "db-not-accepted", "file-not-match"))
|
|
require.False(t, isMigrationChecksumCompatible(name, "db-not-accepted", rule.fileChecksum))
|
|
|
|
var accepted string
|
|
for checksum := range rule.acceptedDBChecksum {
|
|
accepted = checksum
|
|
break
|
|
}
|
|
require.NotEmpty(t, accepted)
|
|
require.True(t, isMigrationChecksumCompatible(name, accepted, rule.fileChecksum))
|
|
}
|
|
|
|
func TestEnsureAtlasBaselineAligned(t *testing.T) {
|
|
t.Run("skip_when_no_legacy_table", func(t *testing.T) {
|
|
db, mock, err := sqlmock.New()
|
|
require.NoError(t, err)
|
|
defer func() { _ = db.Close() }()
|
|
|
|
mock.ExpectQuery("SELECT EXISTS \\(").
|
|
WithArgs("schema_migrations").
|
|
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false))
|
|
|
|
err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{})
|
|
require.NoError(t, err)
|
|
require.NoError(t, mock.ExpectationsWereMet())
|
|
})
|
|
|
|
t.Run("create_atlas_and_insert_baseline_when_empty", func(t *testing.T) {
|
|
db, mock, err := sqlmock.New()
|
|
require.NoError(t, err)
|
|
defer func() { _ = db.Close() }()
|
|
|
|
mock.ExpectQuery("SELECT EXISTS \\(").
|
|
WithArgs("schema_migrations").
|
|
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
|
|
mock.ExpectQuery("SELECT EXISTS \\(").
|
|
WithArgs("atlas_schema_revisions").
|
|
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false))
|
|
mock.ExpectExec("CREATE TABLE IF NOT EXISTS atlas_schema_revisions").
|
|
WillReturnResult(sqlmock.NewResult(0, 0))
|
|
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions").
|
|
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
|
|
mock.ExpectExec("INSERT INTO atlas_schema_revisions").
|
|
WithArgs("002_next", "002_next", 1, sqlmock.AnyArg()).
|
|
WillReturnResult(sqlmock.NewResult(1, 1))
|
|
|
|
fsys := fstest.MapFS{
|
|
"001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t1(id int);")},
|
|
"002_next.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t2(id int);")},
|
|
}
|
|
err = ensureAtlasBaselineAligned(context.Background(), db, fsys)
|
|
require.NoError(t, err)
|
|
require.NoError(t, mock.ExpectationsWereMet())
|
|
})
|
|
|
|
t.Run("error_when_checking_legacy_table", func(t *testing.T) {
|
|
db, mock, err := sqlmock.New()
|
|
require.NoError(t, err)
|
|
defer func() { _ = db.Close() }()
|
|
|
|
mock.ExpectQuery("SELECT EXISTS \\(").
|
|
WithArgs("schema_migrations").
|
|
WillReturnError(errors.New("exists failed"))
|
|
|
|
err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{})
|
|
require.Error(t, err)
|
|
require.Contains(t, err.Error(), "check schema_migrations")
|
|
require.NoError(t, mock.ExpectationsWereMet())
|
|
})
|
|
|
|
t.Run("error_when_counting_atlas_rows", func(t *testing.T) {
|
|
db, mock, err := sqlmock.New()
|
|
require.NoError(t, err)
|
|
defer func() { _ = db.Close() }()
|
|
|
|
mock.ExpectQuery("SELECT EXISTS \\(").
|
|
WithArgs("schema_migrations").
|
|
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
|
|
mock.ExpectQuery("SELECT EXISTS \\(").
|
|
WithArgs("atlas_schema_revisions").
|
|
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
|
|
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions").
|
|
WillReturnError(errors.New("count failed"))
|
|
|
|
err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{})
|
|
require.Error(t, err)
|
|
require.Contains(t, err.Error(), "count atlas_schema_revisions")
|
|
require.NoError(t, mock.ExpectationsWereMet())
|
|
})
|
|
|
|
t.Run("error_when_creating_atlas_table", func(t *testing.T) {
|
|
db, mock, err := sqlmock.New()
|
|
require.NoError(t, err)
|
|
defer func() { _ = db.Close() }()
|
|
|
|
mock.ExpectQuery("SELECT EXISTS \\(").
|
|
WithArgs("schema_migrations").
|
|
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
|
|
mock.ExpectQuery("SELECT EXISTS \\(").
|
|
WithArgs("atlas_schema_revisions").
|
|
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false))
|
|
mock.ExpectExec("CREATE TABLE IF NOT EXISTS atlas_schema_revisions").
|
|
WillReturnError(errors.New("create failed"))
|
|
|
|
err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{})
|
|
require.Error(t, err)
|
|
require.Contains(t, err.Error(), "create atlas_schema_revisions")
|
|
require.NoError(t, mock.ExpectationsWereMet())
|
|
})
|
|
|
|
t.Run("error_when_inserting_baseline", func(t *testing.T) {
|
|
db, mock, err := sqlmock.New()
|
|
require.NoError(t, err)
|
|
defer func() { _ = db.Close() }()
|
|
|
|
mock.ExpectQuery("SELECT EXISTS \\(").
|
|
WithArgs("schema_migrations").
|
|
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
|
|
mock.ExpectQuery("SELECT EXISTS \\(").
|
|
WithArgs("atlas_schema_revisions").
|
|
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
|
|
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions").
|
|
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
|
|
mock.ExpectExec("INSERT INTO atlas_schema_revisions").
|
|
WithArgs("001_init", "001_init", 1, sqlmock.AnyArg()).
|
|
WillReturnError(errors.New("insert failed"))
|
|
|
|
fsys := fstest.MapFS{
|
|
"001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t(id int);")},
|
|
}
|
|
err = ensureAtlasBaselineAligned(context.Background(), db, fsys)
|
|
require.Error(t, err)
|
|
require.Contains(t, err.Error(), "insert atlas baseline")
|
|
require.NoError(t, mock.ExpectationsWereMet())
|
|
})
|
|
}
|
|
|
|
func TestApplyMigrationsFS_ChecksumMismatchRejected(t *testing.T) {
|
|
db, mock, err := sqlmock.New()
|
|
require.NoError(t, err)
|
|
defer func() { _ = db.Close() }()
|
|
|
|
prepareMigrationsBootstrapExpectations(mock)
|
|
mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
|
|
WithArgs("001_init.sql").
|
|
WillReturnRows(sqlmock.NewRows([]string{"checksum"}).AddRow("mismatched-checksum"))
|
|
mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
|
|
WithArgs(migrationsAdvisoryLockID).
|
|
WillReturnResult(sqlmock.NewResult(0, 1))
|
|
|
|
fsys := fstest.MapFS{
|
|
"001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t(id int);")},
|
|
}
|
|
err = applyMigrationsFS(context.Background(), db, fsys)
|
|
require.Error(t, err)
|
|
require.Contains(t, err.Error(), "checksum mismatch")
|
|
require.NoError(t, mock.ExpectationsWereMet())
|
|
}
|
|
|
|
func TestApplyMigrationsFS_CheckMigrationQueryError(t *testing.T) {
|
|
db, mock, err := sqlmock.New()
|
|
require.NoError(t, err)
|
|
defer func() { _ = db.Close() }()
|
|
|
|
prepareMigrationsBootstrapExpectations(mock)
|
|
mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
|
|
WithArgs("001_err.sql").
|
|
WillReturnError(errors.New("query failed"))
|
|
mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
|
|
WithArgs(migrationsAdvisoryLockID).
|
|
WillReturnResult(sqlmock.NewResult(0, 1))
|
|
|
|
fsys := fstest.MapFS{
|
|
"001_err.sql": &fstest.MapFile{Data: []byte("SELECT 1;")},
|
|
}
|
|
err = applyMigrationsFS(context.Background(), db, fsys)
|
|
require.Error(t, err)
|
|
require.Contains(t, err.Error(), "check migration 001_err.sql")
|
|
require.NoError(t, mock.ExpectationsWereMet())
|
|
}
|
|
|
|
func TestApplyMigrationsFS_SkipEmptyAndAlreadyApplied(t *testing.T) {
|
|
db, mock, err := sqlmock.New()
|
|
require.NoError(t, err)
|
|
defer func() { _ = db.Close() }()
|
|
|
|
prepareMigrationsBootstrapExpectations(mock)
|
|
|
|
alreadySQL := "CREATE TABLE t(id int);"
|
|
checksum := migrationChecksum(alreadySQL)
|
|
mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
|
|
WithArgs("001_already.sql").
|
|
WillReturnRows(sqlmock.NewRows([]string{"checksum"}).AddRow(checksum))
|
|
mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
|
|
WithArgs(migrationsAdvisoryLockID).
|
|
WillReturnResult(sqlmock.NewResult(0, 1))
|
|
|
|
fsys := fstest.MapFS{
|
|
"000_empty.sql": &fstest.MapFile{Data: []byte(" \n\t ")},
|
|
"001_already.sql": &fstest.MapFile{Data: []byte(alreadySQL)},
|
|
}
|
|
err = applyMigrationsFS(context.Background(), db, fsys)
|
|
require.NoError(t, err)
|
|
require.NoError(t, mock.ExpectationsWereMet())
|
|
}
|
|
|
|
func TestApplyMigrationsFS_ReadMigrationError(t *testing.T) {
|
|
db, mock, err := sqlmock.New()
|
|
require.NoError(t, err)
|
|
defer func() { _ = db.Close() }()
|
|
|
|
prepareMigrationsBootstrapExpectations(mock)
|
|
mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
|
|
WithArgs(migrationsAdvisoryLockID).
|
|
WillReturnResult(sqlmock.NewResult(0, 1))
|
|
|
|
fsys := fstest.MapFS{
|
|
"001_bad.sql": &fstest.MapFile{Mode: fs.ModeDir},
|
|
}
|
|
err = applyMigrationsFS(context.Background(), db, fsys)
|
|
require.Error(t, err)
|
|
require.Contains(t, err.Error(), "read migration 001_bad.sql")
|
|
require.NoError(t, mock.ExpectationsWereMet())
|
|
}
|
|
|
|
func TestPgAdvisoryLockAndUnlock_ErrorBranches(t *testing.T) {
|
|
t.Run("context_cancelled_while_not_locked", func(t *testing.T) {
|
|
db, mock, err := sqlmock.New()
|
|
require.NoError(t, err)
|
|
defer func() { _ = db.Close() }()
|
|
|
|
mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)").
|
|
WithArgs(migrationsAdvisoryLockID).
|
|
WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(false))
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
|
|
defer cancel()
|
|
err = pgAdvisoryLock(ctx, db)
|
|
require.Error(t, err)
|
|
require.Contains(t, err.Error(), "acquire migrations lock")
|
|
require.NoError(t, mock.ExpectationsWereMet())
|
|
})
|
|
|
|
t.Run("unlock_exec_error", func(t *testing.T) {
|
|
db, mock, err := sqlmock.New()
|
|
require.NoError(t, err)
|
|
defer func() { _ = db.Close() }()
|
|
|
|
mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
|
|
WithArgs(migrationsAdvisoryLockID).
|
|
WillReturnError(errors.New("unlock failed"))
|
|
|
|
err = pgAdvisoryUnlock(context.Background(), db)
|
|
require.Error(t, err)
|
|
require.Contains(t, err.Error(), "release migrations lock")
|
|
require.NoError(t, mock.ExpectationsWereMet())
|
|
})
|
|
|
|
t.Run("acquire_lock_after_retry", func(t *testing.T) {
|
|
db, mock, err := sqlmock.New()
|
|
require.NoError(t, err)
|
|
defer func() { _ = db.Close() }()
|
|
|
|
mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)").
|
|
WithArgs(migrationsAdvisoryLockID).
|
|
WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(false))
|
|
mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)").
|
|
WithArgs(migrationsAdvisoryLockID).
|
|
WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(true))
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), migrationsLockRetryInterval*3)
|
|
defer cancel()
|
|
start := time.Now()
|
|
err = pgAdvisoryLock(ctx, db)
|
|
require.NoError(t, err)
|
|
require.GreaterOrEqual(t, time.Since(start), migrationsLockRetryInterval)
|
|
require.NoError(t, mock.ExpectationsWereMet())
|
|
})
|
|
}
|
|
|
|
// migrationChecksum returns the lowercase hex-encoded SHA-256 digest of
// content after leading and trailing whitespace is trimmed. The tests use it
// to fake the checksum stored for an already-applied migration.
func migrationChecksum(content string) string {
	normalized := strings.TrimSpace(content)
	digest := sha256.Sum256([]byte(normalized))
	return hex.EncodeToString(digest[:])
}
|