Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch interfaces to accept conn pools #90

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ if err != nil {
}
defer tempDbFactory.Close()
// Generate the migration plan
plan, err := diff.GeneratePlan(ctx, conn, tempDbFactory, ddl,
plan, err := diff.GeneratePlan(ctx, connPool, tempDbFactory, ddl,
diff.WithDataPackNewTables(),
)
if err != nil {
Expand Down
8 changes: 1 addition & 7 deletions cmd/pg-schema-diff/plan_cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,13 +242,7 @@ func generatePlan(ctx context.Context, logger log.Logger, connConfig *pgx.ConnCo
}
defer connPool.Close()

conn, err := connPool.Conn(ctx)
if err != nil {
return diff.Plan{}, err
}
defer conn.Close()

plan, err := diff.GeneratePlan(ctx, conn, tempDbFactory, ddl,
plan, err := diff.GeneratePlan(ctx, connPool, tempDbFactory, ddl,
diff.WithDataPackNewTables(),
)
if err != nil {
Expand Down
6 changes: 2 additions & 4 deletions internal/migration_acceptance_tests/acceptance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,6 @@ func (suite *acceptanceTestSuite) runSubtest(tc acceptanceTestCase, expects expe
oldDBConnPool, err := sql.Open("pgx", oldDb.GetDSN())
suite.Require().NoError(err)
defer oldDBConnPool.Close()
oldDbConn, _ := oldDBConnPool.Conn(context.Background())
defer oldDbConn.Close()

tempDbFactory, err := tempdb.NewOnInstanceFactory(context.Background(), func(ctx context.Context, dbName string) (*sql.DB, error) {
return sql.Open("pgx", suite.pgEngine.GetPostgresDatabaseConnOpts().With("dbname", dbName).ToDSN())
Expand All @@ -122,7 +120,7 @@ func (suite *acceptanceTestSuite) runSubtest(tc acceptanceTestCase, expects expe
suite.Require().NoError(tempDbFactory.Close())
}(tempDbFactory)

plan, err := diff.GeneratePlan(context.Background(), oldDbConn, tempDbFactory, tc.newSchemaDDL, planOpts...)
plan, err := diff.GeneratePlan(context.Background(), oldDBConnPool, tempDbFactory, tc.newSchemaDDL, planOpts...)

if expects.planErrorIs != nil || len(expects.planErrorContains) > 0 {
if expects.planErrorIs != nil {
Expand Down Expand Up @@ -166,7 +164,7 @@ func (suite *acceptanceTestSuite) runSubtest(tc acceptanceTestCase, expects expe
}

// Make sure no diff is found if we try to regenerate a plan
plan, err = diff.GeneratePlan(context.Background(), oldDbConn, tempDbFactory, tc.newSchemaDDL, planOpts...)
plan, err = diff.GeneratePlan(context.Background(), oldDBConnPool, tempDbFactory, tc.newSchemaDDL, planOpts...)
suite.Require().NoError(err)
suite.Empty(plan.Statements, prettySprintPlan(plan))
}
Expand Down
16 changes: 13 additions & 3 deletions pkg/diff/plan_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,15 @@ import (
"github.com/stripe/pg-schema-diff/pkg/tempdb"
)

// SQLQueryable represents a queryable database. It is recommended to use *sql.DB or *sql.Conn.
// In a future major version update, we will probably deprecate *sql.Conn support and only support *sql.DB.
type SQLQueryable interface {
ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
PrepareContext(context.Context, string) (*sql.Stmt, error)
QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
QueryRowContext(context.Context, string, ...interface{}) *sql.Row
}

type (
planOptions struct {
dataPackNewTables bool
Expand Down Expand Up @@ -60,12 +69,13 @@ func WithLogger(logger log.Logger) PlanOpt {
// GeneratePlan generates a migration plan to migrate the database to the target schema.
//
// Parameters:
// conn: connection to the target database you wish to migrate.
// sqlQueryable: The target database to generate the diff for. It is recommended to pass in *sql.DB of the db you
// wish to migrate.
// tempDbFactory: used to create a temporary database instance to extract the schema from the new DDL and validate the
// migration plan. It is recommended to use tempdb.NewOnInstanceFactory, or you can provide your own.
// newDDL: DDL encoding the new schema
// opts: Additional options to configure the plan generation
func GeneratePlan(ctx context.Context, conn *sql.Conn, tempDbFactory tempdb.Factory, newDDL []string, opts ...PlanOpt) (Plan, error) {
func GeneratePlan(ctx context.Context, sqlQueryable SQLQueryable, tempDbFactory tempdb.Factory, newDDL []string, opts ...PlanOpt) (Plan, error) {
planOptions := &planOptions{
validatePlan: true,
ignoreChangesToColOrder: true,
Expand All @@ -75,7 +85,7 @@ func GeneratePlan(ctx context.Context, conn *sql.Conn, tempDbFactory tempdb.Fact
opt(planOptions)
}

currentSchema, err := schema.GetPublicSchema(ctx, conn)
currentSchema, err := schema.GetPublicSchema(ctx, sqlQueryable)
if err != nil {
return Plan{}, fmt.Errorf("getting current schema: %w", err)
}
Expand Down
11 changes: 5 additions & 6 deletions pkg/diff/plan_generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,24 +86,23 @@ func (suite *simpleMigratorTestSuite) TestPlanAndApplyMigration() {

suite.mustApplyDDLToTestDb([]string{initialDDL})

conn, poolCloser := suite.mustGetTestDBConn()
defer poolCloser.Close()
defer conn.Close()
connPool := suite.mustGetTestDBPool()
defer connPool.Close()

tempDbFactory := suite.mustBuildTempDbFactory(context.Background())
defer tempDbFactory.Close()

plan, err := diff.GeneratePlan(context.Background(), conn, tempDbFactory, []string{newSchemaDDL})
plan, err := diff.GeneratePlan(context.Background(), connPool, tempDbFactory, []string{newSchemaDDL})
suite.NoError(err)

// Run the migration
for _, stmt := range plan.Statements {
_, err = conn.ExecContext(context.Background(), stmt.ToSQL())
_, err = connPool.ExecContext(context.Background(), stmt.ToSQL())
suite.Require().NoError(err)
}
// Ensure that some sort of migration ran. we're really not testing the correctness of the
// migration in this test suite
_, err = conn.ExecContext(context.Background(),
_, err = connPool.ExecContext(context.Background(),
"SELECT new_column FROM foobar;")
suite.NoError(err)
}
Expand Down