diff --git a/Makefile b/Makefile index cb51332..b69ef5f 100644 --- a/Makefile +++ b/Makefile @@ -56,6 +56,11 @@ postgres: -p 5432:5432 \ postgres:11 +.PHONY: psql +psql: + @echo "---> Running psql" + psql -h localhost -p 5432 -U $(TEST_DATABASE_USER) -d $(TEST_DATABASE_NAME) + .PHONY: release release: @echo "---> Creating new release" diff --git a/README.md b/README.md index a184783..792577c 100644 --- a/README.md +++ b/README.md @@ -32,8 +32,11 @@ files to be saved in (which will be the same directory of the main package, e.g. `example`), an instance of `*pg.DB`, and `os.Args`; and log any potential errors that could be returned. -Once this has been set up, then you can use the `create`, `migrate`, `status`, `rollback`, -`help` commands like so: +You can also call `migrations.RunWithOptions` to configure the way that the +migrations run (e.g. customize the name of the migration tables). + +Once this has been set up, then you can use the `create`, `migrate`, `status`, +`rollback`, `help` commands like so: ``` $ go run example/*.go create create_users_table diff --git a/create.go b/create.go index 05d8ed6..a99d57a 100644 --- a/create.go +++ b/create.go @@ -33,7 +33,7 @@ func init() { } ` -func create(directory, name string) error { +func (m *migrator) create(directory, name string) error { version := time.Now().UTC().Format(timeFormat) fullname := fmt.Sprintf("%s_%s", version, name) filename := path.Join(directory, fullname+".go") diff --git a/create_test.go b/create_test.go index 252bf0b..cad7371 100644 --- a/create_test.go +++ b/create_test.go @@ -15,8 +15,9 @@ func TestCreate(t *testing.T) { r := rand.New(rand.NewSource(time.Now().UnixNano())) tmp := os.TempDir() name := fmt.Sprintf("create_test_migration_%d", r.Int()) + m := newMigrator(nil, RunOptions{}) - err := create(tmp, name) + err := m.create(tmp, name) assert.Nil(t, err) files, err := os.ReadDir(tmp) diff --git a/migrate.go b/migrate.go index 514f09a..ecf586c 100644 --- a/migrate.go +++ b/migrate.go @@ -22,7 +22,7 @@ func Register(name string, up, down func(orm.DB) error, opts MigrationOptions) { }) } -func migrate(db *pg.DB) (err error) { +func (m *migrator) migrate() (err error) { // sort the registered migrations by name (which will sort by the // timestamp in their names) sort.Slice(migrations, func(i, j int) bool { @@ -30,7 +30,7 @@ func migrate(db *pg.DB) (err error) { }) // look at the migrations table to see the already run migrations - completed, err := getCompletedMigrations(db) + completed, err := m.getCompletedMigrations() if err != nil { return err } @@ -46,19 +46,19 @@ func migrate(db *pg.DB) (err error) { } // acquire the migration lock from the migrations_lock table - err = acquireLock(db) + err = m.acquireLock() if err != nil { return err } defer func() { - e := releaseLock(db) + e := m.releaseLock() if e != nil && err == nil { err = e } }() // find the last batch number - batch, err := getLastBatchNumber(db) + batch, err := m.getLastBatchNumber() if err != nil { return err } @@ -66,38 +66,44 @@ func migrate(db *pg.DB) (err error) { fmt.Printf("Running batch %d with %d migration(s)...\n", batch, len(uncompleted)) - for _, m := range uncompleted { - m.Batch = batch + for _, mig := range uncompleted { var err error - if m.DisableTransaction { - err = m.Up(db) + if mig.DisableTransaction { + err = mig.Up(m.db) } else { - err = db.RunInTransaction(db.Context(), func(tx *pg.Tx) error { - return m.Up(tx) + err = m.db.RunInTransaction(m.db.Context(), func(tx *pg.Tx) error { + return mig.Up(tx) }) } if err != nil { - return fmt.Errorf("%s: %s", m.Name, err) + return fmt.Errorf("%s: %s", mig.Name, err) } - m.CompletedAt = time.Now() - _, err = db.Model(m).Insert() + migrationMap := map[string]interface{}{ + "name": mig.Name, + "batch": batch, + "completed_at": time.Now(), + } + _, err = m.db. + Model(&migrationMap). + Table(m.opts.MigrationsTableName). + Insert() if err != nil { - return fmt.Errorf("%s: %s", m.Name, err) + return fmt.Errorf("%s: %s", mig.Name, err) } - fmt.Printf("Finished running %q\n", m.Name) + fmt.Printf("Finished running %q\n", mig.Name) } return nil } -func getCompletedMigrations(db orm.DB) ([]*migration, error) { +func (m *migrator) getCompletedMigrations() ([]*migration, error) { var completed []*migration - err := db. - Model(&completed). + err := orm.NewQuery(m.db). + Table(m.opts.MigrationsTableName). Order("id"). - Select() + Select(&completed) if err != nil { return nil, err } @@ -105,37 +111,18 @@ func getCompletedMigrations(db orm.DB) ([]*migration, error) { return completed, nil } -func filterMigrations(all, subset []*migration, wantCompleted bool) []*migration { - subsetMap := map[string]bool{} - - for _, c := range subset { - subsetMap[c.Name] = true - } - - var d []*migration - - for _, a := range all { - if subsetMap[a.Name] == wantCompleted { - d = append(d, a) - } - } - - return d -} - -func acquireLock(db *pg.DB) error { - l := lock{ID: lockID, IsLocked: true} - - result, err := db.Model(&l). +func (m *migrator) acquireLock() error { + l := map[string]interface{}{"is_locked": true} + result, err := m.db. + Model(&l). + Table(m.opts.MigrationLockTableName). Column("is_locked"). - WherePK(). + Where("id = ?", lockID). Where("is_locked = ?", false). Update() - if err != nil { return err } - if result.RowsAffected() == 0 { return ErrAlreadyLocked } @@ -143,17 +130,21 @@ func acquireLock(db *pg.DB) error { return nil } -func releaseLock(db orm.DB) error { - l := lock{ID: lockID, IsLocked: false} - _, err := db.Model(&l). - WherePK(). +func (m *migrator) releaseLock() error { + l := map[string]interface{}{"is_locked": false} + _, err := m.db. + Model(&l). + Table(m.opts.MigrationLockTableName). + Column("is_locked"). + Where("id = ?", lockID). Update() return err } -func getLastBatchNumber(db orm.DB) (int32, error) { +func (m *migrator) getLastBatchNumber() (int32, error) { var res struct{ Batch int32 } - err := db.Model(&migration{}). + err := orm.NewQuery(m.db). + Table(m.opts.MigrationsTableName). ColumnExpr("COALESCE(MAX(batch), 0) AS batch"). Select(&res) if err != nil { @@ -161,3 +152,19 @@ func getLastBatchNumber(db orm.DB) (int32, error) { } return res.Batch, nil } + +func filterMigrations(all, subset []*migration, wantCompleted bool) []*migration { + subsetMap := map[string]bool{} + for _, c := range subset { + subsetMap[c.Name] = true + } + + var d []*migration + for _, a := range all { + if subsetMap[a.Name] == wantCompleted { + d = append(d, a) + } + } + + return d +} diff --git a/migrate_test.go b/migrate_test.go index 0dc1c11..a338d87 100644 --- a/migrate_test.go +++ b/migrate_test.go @@ -48,10 +48,10 @@ func TestMigrate(t *testing.T) { User: os.Getenv("TEST_DATABASE_USER"), Database: os.Getenv("TEST_DATABASE_NAME"), }) - db.AddQueryHook(logQueryHook{}) + m := newMigrator(db, RunOptions{}) - err := ensureMigrationTables(db) + err := m.ensureMigrationTables() require.Nil(t, err) defer clearMigrations(t, db) @@ -65,7 +65,7 @@ func TestMigrate(t *testing.T) { {Name: "123", Up: noopMigration, Down: noopMigration}, } - err := migrate(db) + err := m.migrate() assert.Nil(tt, err) assert.Equal(tt, "123", migrations[0].Name) @@ -83,7 +83,7 @@ func TestMigrate(t *testing.T) { _, err := db.Model(migrations[0]).Insert() assert.Nil(tt, err) - err = migrate(db) + err = m.migrate() assert.Nil(tt, err) var m []*migration @@ -105,7 +105,7 @@ func TestMigrate(t *testing.T) { _, err := db.Model(&migrations).Insert() assert.Nil(tt, err) - err = migrate(db) + err = m.migrate() assert.Nil(tt, err) count, err := db.Model(&migration{}).Where("batch = 2").Count() @@ -121,11 +121,11 @@ func TestMigrate(t *testing.T) { {Name: "456", Up: noopMigration, Down: noopMigration}, } - err := acquireLock(db) + err := m.acquireLock() assert.Nil(tt, err) - defer releaseLock(db) + defer m.releaseLock() - err = migrate(db) + err = m.migrate() assert.Equal(tt, ErrAlreadyLocked, err) }) @@ -141,10 +141,10 @@ func TestMigrate(t *testing.T) { _, err := db.Model(migrations[0]).Insert() assert.Nil(tt, err) - err = migrate(db) + err = m.migrate() assert.Nil(tt, err) - batch, err := getLastBatchNumber(db) + batch, err := m.getLastBatchNumber() assert.Nil(tt, err) assert.Equal(tt, batch, int32(6)) @@ -160,7 +160,7 @@ func TestMigrate(t *testing.T) { {Name: "123", Up: erringMigration, Down: noopMigration, DisableTransaction: false}, } - err := migrate(db) + err := m.migrate() assert.EqualError(tt, err, "123: error") assertTable(tt, db, "test_table", false) @@ -173,7 +173,7 @@ func TestMigrate(t *testing.T) { {Name: "123", Up: erringMigration, Down: noopMigration, DisableTransaction: true}, } - err := migrate(db) + err := m.migrate() assert.EqualError(tt, err, "123: error") assertTable(tt, db, "test_table", true) @@ -207,6 +207,8 @@ func clearMigrations(t *testing.T, db *pg.DB) { _, err := db.Exec("DELETE FROM migrations") assert.Nil(t, err) + _, err = db.Exec("UPDATE migration_lock SET is_locked = FALSE") + assert.Nil(t, err) _, err = db.Exec("DROP TABLE IF EXISTS test_table") assert.Nil(t, err) } diff --git a/migrations.go b/migrations.go index e4b15ee..fb956d0 100644 --- a/migrations.go +++ b/migrations.go @@ -22,6 +22,19 @@ type MigrationOptions struct { DisableTransaction bool } +// RunOptions allows settings to be configured for the environment that the migrations are run. +type RunOptions struct { + // Set this to configure the table name of the migrations table. The default is `migrations`. Changing this after + // you already have a migrations table does NOT rename it; it assumes you're starting fresh. + MigrationsTableName string + // Set this to configure the table name of the lock table. The default is `migration_lock`. Changing this after you + // already have a migration lock table does NOT rename it; it just creates a new one and leaves any existing ones + // alone. + MigrationLockTableName string +} + +// migration doesn't map to the table that we create to keep track of migrations. To see details of that table, see +// setup.go. This struct has tableName and pg tags because it makes tests easier, but it's not used in non-test code. type migration struct { tableName struct{} `pg:"migrations,alias:migrations"` @@ -35,51 +48,49 @@ type migration struct { DisableTransaction bool `pg:"-"` } -type lock struct { - tableName struct{} `pg:"migration_lock,alias:migration_lock"` - - ID string - IsLocked bool `pg:",use_zero,notnull"` -} - const lockID = "lock" -// Run takes in a directory and an argument slice and runs the appropriate command. +// Run takes in a directory and an argument slice and runs the appropriate command with default options. func Run(db *pg.DB, directory string, args []string) error { - cmd := "" + return RunWithOptions(db, directory, args, RunOptions{}) +} +// RunWithOptions takes in a directory, an argument slice, and run options and runs the appropriate command. +func RunWithOptions(db *pg.DB, directory string, args []string, opts RunOptions) error { + cmd := "" if len(args) > 1 { cmd = args[1] } + m := newMigrator(db, opts) + switch cmd { case "migrate": - err := ensureMigrationTables(db) + err := m.ensureMigrationTables() if err != nil { return err } - return migrate(db) + return m.migrate() case "create": if len(args) < 3 { return ErrCreateRequiresName } name := args[2] - return create(directory, name) + return m.create(directory, name) case "rollback": - err := ensureMigrationTables(db) + err := m.ensureMigrationTables() if err != nil { return err } - - return rollback(db) + return m.rollback() case "status": - err := ensureMigrationTables(db) + err := m.ensureMigrationTables() if err != nil { return err } - return status(db, os.Stdout) + return m.status(os.Stdout) default: help(directory) return nil diff --git a/migrations_test.go b/migrations_test.go index db31147..d3c6530 100644 --- a/migrations_test.go +++ b/migrations_test.go @@ -31,3 +31,52 @@ func TestRun(t *testing.T) { err = Run(db, tmp, []string{"cmd", "rollback"}) assert.Nil(t, err) } + +func TestRunWithOptions(t *testing.T) { + tmp := os.TempDir() + db := pg.Connect(&pg.Options{ + Addr: "localhost:5432", + User: os.Getenv("TEST_DATABASE_USER"), + Database: os.Getenv("TEST_DATABASE_NAME"), + }) + db.AddQueryHook(logQueryHook{}) + + t.Run("default", func(tt *testing.T) { + dropMigrationTables(tt, db) + + err := RunWithOptions(db, tmp, []string{"cmd", "migrate"}, RunOptions{}) + assert.Nil(tt, err) + assertTable(tt, db, "migrations", true) + assertTable(tt, db, "migration_lock", true) + assertTable(tt, db, "custom_migrations", false) + assertTable(tt, db, "custom_migration_lock", false) + }) + + t.Run("custom tables - migrate", func(tt *testing.T) { + dropMigrationTables(tt, db) + + err := RunWithOptions(db, tmp, []string{"cmd", "migrate"}, RunOptions{ + MigrationsTableName: "custom_migrations", + MigrationLockTableName: "custom_migration_lock", + }) + assert.Nil(tt, err) + assertTable(tt, db, "custom_migrations", true) + assertTable(tt, db, "custom_migration_lock", true) + assertTable(tt, db, "migrations", false) + assertTable(tt, db, "migration_lock", false) + }) + + t.Run("custom tables - rollback", func(tt *testing.T) { + dropMigrationTables(tt, db) + + err := RunWithOptions(db, tmp, []string{"cmd", "rollback"}, RunOptions{ + MigrationsTableName: "custom_migrations", + MigrationLockTableName: "custom_migration_lock", + }) + assert.Nil(tt, err) + assertTable(tt, db, "custom_migrations", true) + assertTable(tt, db, "custom_migration_lock", true) + assertTable(tt, db, "migrations", false) + assertTable(tt, db, "migration_lock", false) + }) +} diff --git a/migrator.go b/migrator.go new file mode 100644 index 0000000..635748c --- /dev/null +++ b/migrator.go @@ -0,0 +1,30 @@ +package migrations + +import ( + "strings" + + "github.com/go-pg/pg/v10" +) + +type migrator struct { + db *pg.DB + opts RunOptions +} + +func newMigrator(db *pg.DB, opts RunOptions) *migrator { + if opts.MigrationsTableName == "" { + opts.MigrationsTableName = "migrations" + } + if opts.MigrationLockTableName == "" { + opts.MigrationLockTableName = "migration_lock" + } + + return &migrator{ + db: db, + opts: opts, + } +} + +func escapeTableName(name string) string { + return strings.ReplaceAll(name, `"`, `""`) +} diff --git a/rollback.go b/rollback.go index 40d90c9..ae43f0c 100644 --- a/rollback.go +++ b/rollback.go @@ -7,7 +7,7 @@ import ( "github.com/go-pg/pg/v10" ) -func rollback(db *pg.DB) error { +func (m *migrator) rollback() error { // sort the registered migrations by name (which will sort by the // timestamp in their names) sort.Slice(migrations, func(i, j int) bool { @@ -15,19 +15,19 @@ func rollback(db *pg.DB) error { }) // look at the migrations table to see the already run migrations - completed, err := getCompletedMigrations(db) + completed, err := m.getCompletedMigrations() if err != nil { return err } // acquire the migration lock from the migrations_lock table - err = acquireLock(db) + err = m.acquireLock() if err != nil { return err } - defer releaseLock(db) + defer m.releaseLock() - batch, err := getLastBatchNumber(db) + batch, err := m.getLastBatchNumber() if err != nil { return err } @@ -42,24 +42,25 @@ func rollback(db *pg.DB) error { fmt.Printf("Rolling back batch %d with %d migration(s)...\n", batch, len(rollback)) - for _, m := range rollback { + for _, mig := range rollback { var err error - if m.DisableTransaction { - err = m.Down(db) + if mig.DisableTransaction { + err = mig.Down(m.db) } else { - err = db.RunInTransaction(db.Context(), func(tx *pg.Tx) error { - return m.Down(tx) + err = m.db.RunInTransaction(m.db.Context(), func(tx *pg.Tx) error { + return mig.Down(tx) }) } if err != nil { - return fmt.Errorf("%s: %s", m.Name, err) + return fmt.Errorf("%s: %s", mig.Name, err) } - _, err = db.Model(m).Where("name = ?", m.Name).Delete() + _, err = m.db. + Exec(fmt.Sprintf("DELETE FROM %q WHERE name = ?", escapeTableName(m.opts.MigrationsTableName)), mig.Name) if err != nil { - return fmt.Errorf("%s: %s", m.Name, err) + return fmt.Errorf("%s: %s", mig.Name, err) } - fmt.Printf("Finished rolling back %q\n", m.Name) + fmt.Printf("Finished rolling back %q\n", mig.Name) } return nil @@ -67,10 +68,9 @@ func rollback(db *pg.DB) error { func getMigrationsForBatch(migrations []*migration, batch int32) []*migration { var m []*migration - - for _, migration := range migrations { - if migration.Batch == batch { - m = append(m, migration) + for _, mig := range migrations { + if mig.Batch == batch { + m = append(m, mig) } } diff --git a/rollback_test.go b/rollback_test.go index e169bed..b540c14 100644 --- a/rollback_test.go +++ b/rollback_test.go @@ -16,8 +16,10 @@ func TestRollback(t *testing.T) { User: os.Getenv("TEST_DATABASE_USER"), Database: os.Getenv("TEST_DATABASE_NAME"), }) + db.AddQueryHook(logQueryHook{}) + m := newMigrator(db, RunOptions{}) - err := ensureMigrationTables(db) + err := m.ensureMigrationTables() require.Nil(t, err) defer clearMigrations(t, db) @@ -31,7 +33,7 @@ func TestRollback(t *testing.T) { {Name: "456", Up: noopMigration, Down: noopMigration}, } - err := rollback(db) + err := m.rollback() assert.Nil(tt, err) assert.Equal(tt, "456", migrations[0].Name) @@ -46,11 +48,11 @@ func TestRollback(t *testing.T) { {Name: "456", Up: noopMigration, Down: noopMigration}, } - err := acquireLock(db) + err := m.acquireLock() assert.Nil(tt, err) - defer releaseLock(db) + defer m.releaseLock() - err = rollback(db) + err = m.rollback() assert.Equal(tt, ErrAlreadyLocked, err) }) @@ -62,7 +64,7 @@ func TestRollback(t *testing.T) { {Name: "456", Up: noopMigration, Down: noopMigration}, } - err := rollback(db) + err := m.rollback() assert.Nil(tt, err) count, err := db.Model(&migration{}).Count() @@ -80,14 +82,14 @@ func TestRollback(t *testing.T) { {Name: "010", Up: noopMigration, Down: noopMigration}, } - m := migrations[:2] - _, err := db.Model(&m).Insert() + mig := migrations[:2] + _, err := db.Model(&mig).Insert() assert.Nil(tt, err) - err = rollback(db) + err = m.rollback() assert.Nil(tt, err) - batch, err := getLastBatchNumber(db) + batch, err := m.getLastBatchNumber() assert.Nil(tt, err) assert.Equal(tt, batch, int32(4)) @@ -106,7 +108,7 @@ func TestRollback(t *testing.T) { _, err := db.Model(&migrations).Insert() assert.Nil(tt, err) - err = rollback(db) + err = m.rollback() assert.EqualError(tt, err, "123: error") assertTable(tt, db, "test_table", false) @@ -122,7 +124,7 @@ func TestRollback(t *testing.T) { _, err := db.Model(&migrations).Insert() assert.Nil(tt, err) - err = rollback(db) + err = m.rollback() assert.EqualError(tt, err, "123: error") assertTable(tt, db, "test_table", true) diff --git a/setup.go b/setup.go index 9bf7cc1..1cb6dfc 100644 --- a/setup.go +++ b/setup.go @@ -1,40 +1,49 @@ package migrations import ( - "github.com/go-pg/pg/v10" + "fmt" + "github.com/go-pg/pg/v10/orm" ) -func ensureMigrationTables(db *pg.DB) error { - exists, err := checkIfTableExists("migrations", db) +func (m *migrator) ensureMigrationTables() error { + exists, err := m.checkIfTableExists(m.opts.MigrationsTableName) if err != nil { return err } if !exists { - err = createTable(&migration{}, db) + _, err = m.db.Exec(fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %q ("id" SERIAL PRIMARY KEY, "name" TEXT NOT NULL, "batch" INTEGER NOT NULL, "completed_at" TIMESTAMPTZ NOT NULL)`, escapeTableName(m.opts.MigrationsTableName))) if err != nil { return err } } - exists, err = checkIfTableExists("migration_lock", db) + exists, err = m.checkIfTableExists(m.opts.MigrationLockTableName) if err != nil { return err } if !exists { - err = createTable(&lock{}, db) + _, err = m.db.Exec(fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %q ("id" TEXT PRIMARY KEY, "is_locked" BOOLEAN NOT NULL)`, escapeTableName(m.opts.MigrationLockTableName))) if err != nil { return err } } - count, err := db.Model(&lock{}).Count() + count, err := orm.NewQuery(m.db). + Table(m.opts.MigrationLockTableName). + Count() if err != nil { return err } if count == 0 { - l := lock{ID: lockID, IsLocked: false} - _, err = db.Model(&l).Insert() + l := map[string]interface{}{ + "id": lockID, + "is_locked": false, + } + _, err = m.db. + Model(&l). + Table(m.opts.MigrationLockTableName). + Insert() if err != nil { return err } @@ -43,8 +52,8 @@ func ensureMigrationTables(db *pg.DB) error { return nil } -func checkIfTableExists(name string, db orm.DB) (bool, error) { - count, err := orm.NewQuery(db). +func (m *migrator) checkIfTableExists(name string) (bool, error) { + count, err := orm.NewQuery(m.db). Table("information_schema.tables"). Where("table_name = ?", name). Where("table_schema = current_schema"). @@ -54,8 +63,3 @@ func checkIfTableExists(name string, db orm.DB) (bool, error) { } return count > 0, nil } - -func createTable(model interface{}, db *pg.DB) error { - opts := orm.CreateTableOptions{IfNotExists: true} - return db.Model(model).CreateTable(&opts) -} diff --git a/setup_test.go b/setup_test.go index 789f5af..f635e4e 100644 --- a/setup_test.go +++ b/setup_test.go @@ -15,11 +15,12 @@ func TestEnsureMigrationTables(t *testing.T) { User: os.Getenv("TEST_DATABASE_USER"), Database: os.Getenv("TEST_DATABASE_NAME"), }) + m := newMigrator(db, RunOptions{}) // drop tables to start from a clean database dropMigrationTables(t, db) - err := ensureMigrationTables(db) + err := m.ensureMigrationTables() assert.Nil(t, err) tables := []string{"migrations", "migration_lock"} @@ -31,7 +32,7 @@ func TestEnsureMigrationTables(t *testing.T) { assertOneLock(t, db) // with existing tables, ensureMigrationTables should do anything - err = ensureMigrationTables(db) + err = m.ensureMigrationTables() assert.Nil(t, err) for _, table := range tables { @@ -44,9 +45,13 @@ func TestEnsureMigrationTables(t *testing.T) { func dropMigrationTables(t *testing.T, db *pg.DB) { t.Helper() - _, err := db.Exec("DROP TABLE migrations") + _, err := db.Exec("DROP TABLE IF EXISTS migrations") assert.Nil(t, err) - _, err = db.Exec("DROP TABLE migration_lock") + _, err = db.Exec("DROP TABLE IF EXISTS migration_lock") + assert.Nil(t, err) + _, err = db.Exec("DROP TABLE IF EXISTS custom_migrations") + assert.Nil(t, err) + _, err = db.Exec("DROP TABLE IF EXISTS custom_migration_lock") assert.Nil(t, err) } diff --git a/status.go b/status.go index 65b85cb..214106d 100644 --- a/status.go +++ b/status.go @@ -7,15 +7,9 @@ import ( "sort" "strings" "unicode/utf8" - - "github.com/go-pg/pg/v10" ) -type migrationWithStatus struct { - migration -} - -func status(db *pg.DB, w io.Writer) error { +func (m *migrator) status(w io.Writer) error { // sort the registered migrations by name (which will sort by the // timestamp in their names) sort.Slice(migrations, func(i, j int) bool { @@ -23,7 +17,7 @@ func status(db *pg.DB, w io.Writer) error { }) // look at the migrations table to see the already run migrations - completed, err := getCompletedMigrations(db) + completed, err := m.getCompletedMigrations() if err != nil { return err } @@ -35,7 +29,7 @@ func status(db *pg.DB, w io.Writer) error { return writeStatusTable(w, completed, uncompleted) } -func writeStatusTable(w io.Writer, completed []migration, uncompleted []migration) error { +func writeStatusTable(w io.Writer, completed []*migration, uncompleted []*migration) error { if len(completed)+len(uncompleted) == 0 { _, err := fmt.Fprintln(w, "No migrations found") return err @@ -76,7 +70,6 @@ func writeStatusTable(w io.Writer, completed []migration, uncompleted []migratio func maxInt(a, b int) int { if a > b { return a - } else { - return b } + return b } diff --git a/status_test.go b/status_test.go index 06ca067..d159653 100644 --- a/status_test.go +++ b/status_test.go @@ -16,20 +16,20 @@ func TestStatus(t *testing.T) { User: os.Getenv("TEST_DATABASE_USER"), Database: os.Getenv("TEST_DATABASE_NAME"), }) - db.AddQueryHook(logQueryHook{}) + m := newMigrator(db, RunOptions{}) - err := ensureMigrationTables(db) + err := m.ensureMigrationTables() require.Nil(t, err) defer clearMigrations(t, db) defer resetMigrations(t) - completed := []migration{ + completed := []*migration{ {Name: "2021_02_26_151503_dump", Up: noopMigration, Down: noopMigration, Batch: 1}, {Name: "2021_02_26_151504_create_a_dump_table_for_test", Up: noopMigration, Down: noopMigration, Batch: 2}, } - uncompleted := []migration{ + uncompleted := []*migration{ {Name: "2021_02_26_151502_create_2nd_dump_table", Up: noopMigration, Down: noopMigration}, {Name: "2021_02_26_151505_create_3rd_dump_table", Up: noopMigration, Down: noopMigration}, } @@ -49,15 +49,15 @@ func TestStatus(t *testing.T) { resetMigrations(tt) migrations = completed[:1] - err := migrate(db) + err := m.migrate() require.Nil(tt, err, "migrate: %v", err) migrations = completed[:2] - err = migrate(db) + err = m.migrate() require.Nil(tt, err, "migrate: %v", err) migrations = append(migrations, uncompleted...) bf := bytes.NewBuffer(nil) - err = status(db, bf) + err = m.status(bf) require.Nil(tt, err, "status: %v", err) got := strings.TrimSpace(bf.String())