Skip to content

Commit

Permalink
feat(options): allow customizing the names of the migration tables (#38)
Browse files Browse the repository at this point in the history
  • Loading branch information
robinjoseph08 authored Oct 17, 2024
1 parent 6fa28e1 commit 49ca150
Show file tree
Hide file tree
Showing 15 changed files with 264 additions and 152 deletions.
5 changes: 5 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ postgres:
-p 5432:5432 \
postgres:11

.PHONY: psql
psql:
@echo "---> Running psql"
psql -h localhost -p 5432 -U $(TEST_DATABASE_USER) -d $(TEST_DATABASE_NAME)

.PHONY: release
release:
@echo "---> Creating new release"
Expand Down
7 changes: 5 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,11 @@ files to be saved in (which will be the same directory of the main package, e.g.
`example`), an instance of `*pg.DB`, and `os.Args`; and log any potential errors
that could be returned.

Once this has been set up, then you can use the `create`, `migrate`, `status`, `rollback`,
`help` commands like so:
You can also call `migrations.RunWithOptions` to configure the way that the
migrations run (e.g. customize the name of the migration tables).

Once this has been set up, then you can use the `create`, `migrate`, `status`,
`rollback`, `help` commands like so:

```
$ go run example/*.go create create_users_table
Expand Down
2 changes: 1 addition & 1 deletion create.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func init() {
}
`

func create(directory, name string) error {
func (m *migrator) create(directory, name string) error {
version := time.Now().UTC().Format(timeFormat)
fullname := fmt.Sprintf("%s_%s", version, name)
filename := path.Join(directory, fullname+".go")
Expand Down
3 changes: 2 additions & 1 deletion create_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ func TestCreate(t *testing.T) {
r := rand.New(rand.NewSource(time.Now().UnixNano()))
tmp := os.TempDir()
name := fmt.Sprintf("create_test_migration_%d", r.Int())
m := newMigrator(nil, RunOptions{})

err := create(tmp, name)
err := m.create(tmp, name)
assert.Nil(t, err)

files, err := os.ReadDir(tmp)
Expand Down
109 changes: 58 additions & 51 deletions migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@ func Register(name string, up, down func(orm.DB) error, opts MigrationOptions) {
})
}

func migrate(db *pg.DB) (err error) {
func (m *migrator) migrate() (err error) {
// sort the registered migrations by name (which will sort by the
// timestamp in their names)
sort.Slice(migrations, func(i, j int) bool {
return migrations[i].Name < migrations[j].Name
})

// look at the migrations table to see the already run migrations
completed, err := getCompletedMigrations(db)
completed, err := m.getCompletedMigrations()
if err != nil {
return err
}
Expand All @@ -46,118 +46,125 @@ func migrate(db *pg.DB) (err error) {
}

// acquire the migration lock from the migrations_lock table
err = acquireLock(db)
err = m.acquireLock()
if err != nil {
return err
}
defer func() {
e := releaseLock(db)
e := m.releaseLock()
if e != nil && err == nil {
err = e
}
}()

// find the last batch number
batch, err := getLastBatchNumber(db)
batch, err := m.getLastBatchNumber()
if err != nil {
return err
}
batch++

fmt.Printf("Running batch %d with %d migration(s)...\n", batch, len(uncompleted))

for _, m := range uncompleted {
m.Batch = batch
for _, mig := range uncompleted {
var err error
if m.DisableTransaction {
err = m.Up(db)
if mig.DisableTransaction {
err = mig.Up(m.db)
} else {
err = db.RunInTransaction(db.Context(), func(tx *pg.Tx) error {
return m.Up(tx)
err = m.db.RunInTransaction(m.db.Context(), func(tx *pg.Tx) error {
return mig.Up(tx)
})
}
if err != nil {
return fmt.Errorf("%s: %s", m.Name, err)
return fmt.Errorf("%s: %s", mig.Name, err)
}

m.CompletedAt = time.Now()
_, err = db.Model(m).Insert()
migrationMap := map[string]interface{}{
"name": mig.Name,
"batch": batch,
"completed_at": time.Now(),
}
_, err = m.db.
Model(&migrationMap).
Table(m.opts.MigrationsTableName).
Insert()
if err != nil {
return fmt.Errorf("%s: %s", m.Name, err)
return fmt.Errorf("%s: %s", mig.Name, err)
}
fmt.Printf("Finished running %q\n", m.Name)
fmt.Printf("Finished running %q\n", mig.Name)
}

return nil
}

func getCompletedMigrations(db orm.DB) ([]*migration, error) {
func (m *migrator) getCompletedMigrations() ([]*migration, error) {
var completed []*migration

err := db.
Model(&completed).
err := orm.NewQuery(m.db).
Table(m.opts.MigrationsTableName).
Order("id").
Select()
Select(&completed)
if err != nil {
return nil, err
}

return completed, nil
}

func filterMigrations(all, subset []*migration, wantCompleted bool) []*migration {
subsetMap := map[string]bool{}

for _, c := range subset {
subsetMap[c.Name] = true
}

var d []*migration

for _, a := range all {
if subsetMap[a.Name] == wantCompleted {
d = append(d, a)
}
}

return d
}

func acquireLock(db *pg.DB) error {
l := lock{ID: lockID, IsLocked: true}

result, err := db.Model(&l).
func (m *migrator) acquireLock() error {
l := map[string]interface{}{"is_locked": true}
result, err := m.db.
Model(&l).
Table(m.opts.MigrationLockTableName).
Column("is_locked").
WherePK().
Where("id = ?", lockID).
Where("is_locked = ?", false).
Update()

if err != nil {
return err
}

if result.RowsAffected() == 0 {
return ErrAlreadyLocked
}

return nil
}

func releaseLock(db orm.DB) error {
l := lock{ID: lockID, IsLocked: false}
_, err := db.Model(&l).
WherePK().
func (m *migrator) releaseLock() error {
l := map[string]interface{}{"is_locked": false}
_, err := m.db.
Model(&l).
Table(m.opts.MigrationLockTableName).
Column("is_locked").
Where("id = ?", lockID).
Update()
return err
}

func getLastBatchNumber(db orm.DB) (int32, error) {
func (m *migrator) getLastBatchNumber() (int32, error) {
var res struct{ Batch int32 }
err := db.Model(&migration{}).
err := orm.NewQuery(m.db).
Table(m.opts.MigrationsTableName).
ColumnExpr("COALESCE(MAX(batch), 0) AS batch").
Select(&res)
if err != nil {
return 0, err
}
return res.Batch, nil
}

func filterMigrations(all, subset []*migration, wantCompleted bool) []*migration {
subsetMap := map[string]bool{}
for _, c := range subset {
subsetMap[c.Name] = true
}

var d []*migration
for _, a := range all {
if subsetMap[a.Name] == wantCompleted {
d = append(d, a)
}
}

return d
}
26 changes: 14 additions & 12 deletions migrate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ func TestMigrate(t *testing.T) {
User: os.Getenv("TEST_DATABASE_USER"),
Database: os.Getenv("TEST_DATABASE_NAME"),
})

db.AddQueryHook(logQueryHook{})
m := newMigrator(db, RunOptions{})

err := ensureMigrationTables(db)
err := m.ensureMigrationTables()
require.Nil(t, err)

defer clearMigrations(t, db)
Expand All @@ -65,7 +65,7 @@ func TestMigrate(t *testing.T) {
{Name: "123", Up: noopMigration, Down: noopMigration},
}

err := migrate(db)
err := m.migrate()
assert.Nil(tt, err)

assert.Equal(tt, "123", migrations[0].Name)
Expand All @@ -83,7 +83,7 @@ func TestMigrate(t *testing.T) {
_, err := db.Model(migrations[0]).Insert()
assert.Nil(tt, err)

err = migrate(db)
err = m.migrate()
assert.Nil(tt, err)

var m []*migration
Expand All @@ -105,7 +105,7 @@ func TestMigrate(t *testing.T) {
_, err := db.Model(&migrations).Insert()
assert.Nil(tt, err)

err = migrate(db)
err = m.migrate()
assert.Nil(tt, err)

count, err := db.Model(&migration{}).Where("batch = 2").Count()
Expand All @@ -121,11 +121,11 @@ func TestMigrate(t *testing.T) {
{Name: "456", Up: noopMigration, Down: noopMigration},
}

err := acquireLock(db)
err := m.acquireLock()
assert.Nil(tt, err)
defer releaseLock(db)
defer m.releaseLock()

err = migrate(db)
err = m.migrate()
assert.Equal(tt, ErrAlreadyLocked, err)
})

Expand All @@ -141,10 +141,10 @@ func TestMigrate(t *testing.T) {
_, err := db.Model(migrations[0]).Insert()
assert.Nil(tt, err)

err = migrate(db)
err = m.migrate()
assert.Nil(tt, err)

batch, err := getLastBatchNumber(db)
batch, err := m.getLastBatchNumber()
assert.Nil(tt, err)
assert.Equal(tt, batch, int32(6))

Expand All @@ -160,7 +160,7 @@ func TestMigrate(t *testing.T) {
{Name: "123", Up: erringMigration, Down: noopMigration, DisableTransaction: false},
}

err := migrate(db)
err := m.migrate()
assert.EqualError(tt, err, "123: error")

assertTable(tt, db, "test_table", false)
Expand All @@ -173,7 +173,7 @@ func TestMigrate(t *testing.T) {
{Name: "123", Up: erringMigration, Down: noopMigration, DisableTransaction: true},
}

err := migrate(db)
err := m.migrate()
assert.EqualError(tt, err, "123: error")

assertTable(tt, db, "test_table", true)
Expand Down Expand Up @@ -207,6 +207,8 @@ func clearMigrations(t *testing.T, db *pg.DB) {

_, err := db.Exec("DELETE FROM migrations")
assert.Nil(t, err)
_, err = db.Exec("UPDATE migration_lock SET is_locked = FALSE")
assert.Nil(t, err)
_, err = db.Exec("DROP TABLE IF EXISTS test_table")
assert.Nil(t, err)
}
Expand Down
Loading

0 comments on commit 49ca150

Please sign in to comment.