diff --git a/.editorconfig b/.editorconfig
new file mode 100644
index 0000000..e90e11a
--- /dev/null
+++ b/.editorconfig
@@ -0,0 +1,22 @@
+root = true
+
+[*]
+charset = utf-8
+end_of_line = lf
+insert_final_newline = true
+trim_trailing_whitespace = true
+
+[{*.go,Makefile,.gitmodules,go.mod,go.sum}]
+indent_style = tab
+
+[*.md]
+indent_style = tab
+trim_trailing_whitespace = false
+
+[*.{yml,yaml,json}]
+indent_style = space
+indent_size = 2
+
+[*.{js,jsx,ts,tsx,css,less,sass,scss,vue,py}]
+indent_style = space
+indent_size = 4
diff --git a/.github/workflows/code-gen.yaml b/.github/workflows/code-gen.yaml
new file mode 100644
index 0000000..1fdd72d
--- /dev/null
+++ b/.github/workflows/code-gen.yaml
@@ -0,0 +1,23 @@
+name: code_gen
+on:
+ push:
+ branches:
+ - main
+ pull_request:
+permissions:
+ contents: read
+jobs:
+ code_gen:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v3
+ - name: Build Docker image
+ run: docker build -t pg-schema-diff-code-gen-runner -f ./build/Dockerfile.codegen .
+ - name: Run codegen
+ run: docker run -v $(pwd):/pg-schema-diff -w /pg-schema-diff pg-schema-diff-code-gen-runner
+ - name: Check for changes
+ run: |
+ chmod +x ./build/ci-scripts/assert-no-diff.sh
+ ./build/ci-scripts/assert-no-diff.sh
+ shell: bash
+
diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml
new file mode 100644
index 0000000..44dfa27
--- /dev/null
+++ b/.github/workflows/lint.yaml
@@ -0,0 +1,21 @@
+name: lint
+on:
+ push:
+ branches:
+ - main
+ pull_request:
+permissions:
+ contents: read
+jobs:
+ lint:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/setup-go@v4
+ with:
+ go-version: '1.18'
+ cache: false
+ - uses: actions/checkout@v3
+ - name: golangci-lint
+ uses: golangci/golangci-lint-action@v3
+ with:
+ version: v1.52.2
diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml
new file mode 100644
index 0000000..80a806f
--- /dev/null
+++ b/.github/workflows/run-tests.yaml
@@ -0,0 +1,20 @@
+name: run_tests
+
+on:
+ push:
+ branches:
+ - main
+ pull_request:
+jobs:
+ run_tests:
+ runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ pg_version: [14, 15]
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v3
+ - name: Build Docker image
+ run: docker build -t pg-schema-diff-test-runner -f ./build/Dockerfile.test --build-arg POSTGRES_PACKAGE=postgresql${{ matrix.pg_version }} .
+ - name: Run tests
+ run: docker run pg-schema-diff-test-runner
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..550b268
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,26 @@
+# Mac OS X files
+.DS_Store
+
+# Binaries for programs and plugins
+bin/
+*.exe
+*.exe~
+*.dll
+*.so
+*.dylib
+
+# Test binary, build with `go test -c`
+*.test
+!/build/Dockerfile.test
+
+# Output of the go coverage tool, specifically when used with LiteIDE
+*.out
+
+# Project-local glide cache, RE: https://github.com/Masterminds/glide/issues/736
+.glide/
+
+# Dependency directories
+vendor/
+
+# Intellij
+.idea/
diff --git a/.golangci.yaml b/.golangci.yaml
new file mode 100644
index 0000000..0fbdc18
--- /dev/null
+++ b/.golangci.yaml
@@ -0,0 +1,9 @@
+linters:
+ disable-all: true
+ enable:
+ - goimports
+ - ineffassign
+ - staticcheck
+ - typecheck
+ - unused
+
diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md
new file mode 100644
index 0000000..25c4774
--- /dev/null
+++ b/CODE_OF_CONDUCT.md
@@ -0,0 +1,76 @@
+# Contributor Covenant Code of Conduct
+
+## Our Pledge
+
+In the interest of fostering an open and welcoming environment, we as
+contributors and maintainers pledge to making participation in our project and
+our community a harassment-free experience for everyone, regardless of age, body
+size, disability, ethnicity, sex characteristics, gender identity and expression,
+level of experience, education, socio-economic status, nationality, personal
+appearance, race, religion, or sexual identity and orientation.
+
+## Our Standards
+
+Examples of behavior that contributes to creating a positive environment
+include:
+
+* Using welcoming and inclusive language
+* Being respectful of differing viewpoints and experiences
+* Gracefully accepting constructive criticism
+* Focusing on what is best for the community
+* Showing empathy towards other community members
+
+Examples of unacceptable behavior by participants include:
+
+* The use of sexualized language or imagery and unwelcome sexual attention or
+ advances
+* Trolling, insulting/derogatory comments, and personal or political attacks
+* Public or private harassment
+* Publishing others' private information, such as a physical or electronic
+ address, without explicit permission
+* Other conduct which could reasonably be considered inappropriate in a
+ professional setting
+
+## Our Responsibilities
+
+Project maintainers are responsible for clarifying the standards of acceptable
+behavior and are expected to take appropriate and fair corrective action in
+response to any instances of unacceptable behavior.
+
+Project maintainers have the right and responsibility to remove, edit, or
+reject comments, commits, code, wiki edits, issues, and other contributions
+that are not aligned to this Code of Conduct, or to ban temporarily or
+permanently any contributor for other behaviors that they deem inappropriate,
+threatening, offensive, or harmful.
+
+## Scope
+
+This Code of Conduct applies both within project spaces and in public spaces
+when an individual is representing the project or its community. Examples of
+representing a project or community include using an official project e-mail
+address, posting via an official social media account, or acting as an appointed
+representative at an online or offline event. Representation of a project may be
+further defined and clarified by project maintainers.
+
+## Enforcement
+
+Instances of abusive, harassing, or otherwise unacceptable behavior may be
+reported by contacting the project team at conduct@stripe.com. All
+complaints will be reviewed and investigated and will result in a response that
+is deemed necessary and appropriate to the circumstances. The project team is
+obligated to maintain confidentiality with regard to the reporter of an incident.
+Further details of specific enforcement policies may be posted separately.
+
+Project maintainers who do not follow or enforce the Code of Conduct in good
+faith may face temporary or permanent repercussions as determined by other
+members of the project's leadership.
+
+## Attribution
+
+This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
+available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
+
+[homepage]: https://www.contributor-covenant.org
+
+For answers to common questions about this code of conduct, see
+https://www.contributor-covenant.org/faq
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000..c113636
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,24 @@
+# Contributing
+
+This project is in its early stages. We appreciate all the feature/bug requests we receive, but we have limited cycles
+to review direct code contributions at this time. We will try to respond to any bug reports, feature requests, and
+questions within one week.
+
+If you want to make changes yourself, follow these steps:
+
+1. [Fork](https://help.github.com/articles/fork-a-repo/) this repository and [clone](https://help.github.com/articles/cloning-a-repository/) it locally.
+2. Make your changes
+3. Test your changes:
+   ```bash
+   # Build the test image (use POSTGRES_PACKAGE=postgresql15 to test against Postgres 15), then run it
+   docker build -t pg-schema-diff-test-runner -f ./build/Dockerfile.test --build-arg POSTGRES_PACKAGE=postgresql14 .
+   docker run pg-schema-diff-test-runner
+   ```
+4. Submit a [pull request](https://help.github.com/articles/creating-a-pull-request-from-a-fork/)
+
+## Contributor License Agreement ([CLA](https://en.wikipedia.org/wiki/Contributor_License_Agreement))
+
+Once you have submitted a pull request, sign the CLA by clicking on the badge in the comment from [@CLAassistant](https://github.com/CLAassistant).
+
+
+
+
+Thanks for contributing to Stripe!
diff --git a/LICENSE.md b/LICENSE.md
new file mode 100644
index 0000000..34f864a
--- /dev/null
+++ b/LICENSE.md
@@ -0,0 +1,21 @@
+The MIT License
+
+Copyright (c) 2023- Stripe, Inc. (https://stripe.com)
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..b75b59e
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,16 @@
+.PHONY: code_gen format lint test vendor sqlc
+
+code_gen: sqlc
+
+format:
+ goimports -w .
+
+lint:
+ golangci-lint run
+
+sqlc:
+ cd internal/queries && sqlc generate
+
+vendor:
+ go mod vendor
+
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..ecdb281
--- /dev/null
+++ b/README.md
@@ -0,0 +1,104 @@
+# pg-schema-diff
+
+Diffs Postgres database schemas and generates the SQL required to get your database schema from point A to B. This
+enables you to take your database and migrate it to any desired schema defined with plain DDL.
+
+The tooling attempts to use native postgres migration operations and avoid locking wherever possible. Not all migrations will
+be lock-free and some might require downtime, but the hazards system will warn you ahead of time when that's the case.
+Stateful online migration techniques, like shadow tables, aren't yet supported.
+
+```
+pg-schema-diff plan --dsn "postgres://postgres:postgres@localhost:5432/postgres" --schema-dir schema
+
+################################ Generated plan ################################
+1. ALTER TABLE "foobar" ADD COLUMN "fizz" character varying(255) COLLATE "pg_catalog"."default";
+ -- Timeout: 3s
+
+2. CREATE INDEX CONCURRENTLY fizz_idx ON public.foobar USING btree (fizz);
+ -- Timeout: 20m0s
+ -- Hazard INDEX_BUILD: This might affect database performance. Concurrent index builds require a non-trivial amount of CPU, potentially affecting database performance. They also can take a while but do not lock out writes.
+```
+
+# Key features
+*Broad support for diffing & applying arbitrary postgres schemas defined in declarative DDL:*
+- Tables
+- Columns
+- Check Constraints
+- Indexes
+- Partitions
+- Functions/Triggers (functions created by extensions are ignored)
+
+*A comprehensive set of features to ensure the safety of planned migrations:*
+- Dangerous operations are flagged as hazards and must be approved before a migration can be applied.
+ - Data deletion hazards identify operations which will in some way delete or alter data.
+ - Downtime/locking hazards identify operations which will impede or stop other queries.
+ - Performance hazards identify operations which are resource intensive and might slow other queries.
+- Migration plans are validated first against a temporary database exactly as they would be performed against the real database.
+- The library is tested against an extensive suite of unit and acceptance tests.
+
+*The use of postgres native operations for zero-downtime migrations wherever possible:*
+- Concurrent index builds
+- Online index replacement
+
+# Install
+## Library
+```bash
+go get -u github.com/stripe/pg-schema-diff
+````
+## CLI
+```bash
+go install github.com/stripe/pg-schema-diff/cmd/pg-schema-diff
+```
+
+# Using CLI
+## 1. Apply schema to fresh database
+Create a directory to hold your schema files. Then generate SQL files and place them in the schema directory.
+```bash
+mkdir schema
+echo "CREATE TABLE foobar (id int);" > schema/foobar.sql
+echo "CREATE TABLE bar (id varchar(255), message TEXT NOT NULL);" > schema/bar.sql
+```
+
+Apply the schema to a fresh database. [The connection string spec can be found here](https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING).
+Setting the `PGPASSWORD` env var will override any password set in the connection string and is recommended.
+```bash
+pg-schema-diff apply --dsn "postgres://postgres:postgres@localhost:5432/postgres" --schema-dir schema
+```
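+
+For example, a sketch of supplying the password via the environment instead of embedding it in the DSN:
+```bash
+PGPASSWORD=postgres pg-schema-diff apply --dsn "postgres://postgres@localhost:5432/postgres" --schema-dir schema
+```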
+
+## 2. Updating schema
+Update the SQL file(s)
+```bash
+echo "CREATE INDEX message_idx ON bar(message)" >> schema/bar.sql
+```
+
+Apply the schema. Any hazards in the generated plan must be approved:
+```bash
+pg-schema-diff apply --dsn "postgres://postgres:postgres@localhost:5432/postgres" --schema-dir schema --allow-hazards INDEX_BUILD
+```
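+
+The plan/apply flow also exposes flags for tuning the generated plan (defined in `cmd/pg-schema-diff/plan_cmd.go`). A sketch of overriding statement timeouts by regex and inserting a user-supplied statement:
+```bash
+pg-schema-diff plan --dsn "postgres://postgres:postgres@localhost:5432/postgres" --schema-dir schema \
+    --statement-timeout-modifier 'CONCURRENTLY=20m' \
+    --insert-statement '0 5s:SELECT 1'
+```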
+
+# Supported Postgres versions
+- 14 (tested with 14.7)
+- 15 (tested with 15.2)
+
+Postgres v13 and below are not supported. Use at your own risk.
+
+# Unsupported migrations
+Note: the library currently only supports diffing the *public* schema. Support for diffing other schemas is on the roadmap.
+
+*Unsupported*:
+- (On roadmap) Foreign key constraints
+- (On roadmap) Diffing schemas other than "public"
+- (On roadmap) Serials and sequences
+- (On roadmap) Unique constraints (unique indexes are supported but not unique constraints)
+- (On roadmap) Adding and removing partitions from an existing partitioned table
+- (On roadmap) Check constraints localized to specific partitions
+- Partitioned partitions (partitioned tables are supported but not partitioned partitions)
+- Materialized views
+- Renaming. The diffing library relies on names to identify the old and new versions of a table, index, etc. If you rename
+an object, it will be treated as a drop and an add.
+
+# Contributing
+This project is in its early stages. We appreciate all the feature/bug requests we receive, but we have limited cycles
+to review direct code contributions at this time. See [Contributing](CONTRIBUTING.md) to learn more.
+
+
diff --git a/build/Dockerfile.codegen b/build/Dockerfile.codegen
new file mode 100644
index 0000000..0940591
--- /dev/null
+++ b/build/Dockerfile.codegen
@@ -0,0 +1,10 @@
+FROM golang:1.18.10-alpine3.17
+
+RUN apk update && \
+ apk add --no-cache \
+ build-base \
+ git \
+ make
+
+RUN go install github.com/kyleconroy/sqlc/cmd/sqlc@v1.13.0
+ENTRYPOINT make code_gen
diff --git a/build/Dockerfile.test b/build/Dockerfile.test
new file mode 100644
index 0000000..68c292c
--- /dev/null
+++ b/build/Dockerfile.test
@@ -0,0 +1,26 @@
+FROM golang:1.18.10-alpine3.17
+
+ARG POSTGRES_PACKAGE
+
+RUN apk update && \
+ apk add --no-cache \
+ build-base \
+ make \
+ $POSTGRES_PACKAGE \
+ postgresql-contrib \
+ postgresql-client
+
+WORKDIR /pg-schema-diff
+
+COPY . .
+
+# Download dependencies so they are cached in a layer
+RUN go mod download
+
+# Run all tests from non-root. This will also prevent Postgres from complaining when
+# we try to launch it within tests
+RUN adduser --disabled-password --gecos '' testrunner
+USER testrunner
+
+# Run tests serially so logs can be streamed. Set overall timeout to 30m (the default is 10m, which is not enough)
+ENTRYPOINT ["go", "test", "-v", "-race", "-p", "1", "./...", "-timeout", "30m"]
diff --git a/build/ci-scripts/assert-no-diff.sh b/build/ci-scripts/assert-no-diff.sh
new file mode 100755
index 0000000..b6bb518
--- /dev/null
+++ b/build/ci-scripts/assert-no-diff.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+
+git_status=$(git status --porcelain)
+if [[ -n $git_status ]]; then
+ echo "Changes to generated files detected $git_status"
+ exit 1
+fi
diff --git a/cmd/pg-schema-diff/apply_cmd.go b/cmd/pg-schema-diff/apply_cmd.go
new file mode 100644
index 0000000..24eda9f
--- /dev/null
+++ b/cmd/pg-schema-diff/apply_cmd.go
@@ -0,0 +1,173 @@
+package main
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "sort"
+ "strings"
+ "time"
+
+ "github.com/jackc/pgx/v4"
+ "github.com/spf13/cobra"
+ "github.com/stripe/pg-schema-diff/pkg/diff"
+ "github.com/stripe/pg-schema-diff/pkg/log"
+)
+
+func buildApplyCmd() *cobra.Command {
+ cmd := &cobra.Command{
+ Use: "apply",
+ Short: "Migrate your database to match the inputted schema (apply the schema to the database)",
+ }
+
+ connFlags := createConnFlags(cmd)
+ planFlags := createPlanFlags(cmd)
+ allowedHazardsTypesStrs := cmd.Flags().StringSlice("allow-hazards", nil,
+ "Specify the hazards that are allowed. Order does not matter, and duplicates are ignored. If the"+
+ " migration plan contains unwanted hazards (hazards not in this list), then the migration will fail to run"+
+ " (example: --allowed-hazards DELETES_DATA,INDEX_BUILD)")
+ lockTimeout := cmd.Flags().Duration("lock-timeout", 30*time.Second, "the max time to wait to acquire a lock. 0 implies no timeout")
+ cmd.RunE = func(cmd *cobra.Command, args []string) error {
+ connConfig, err := connFlags.parseConnConfig()
+ if err != nil {
+ return err
+ }
+
+ planConfig, err := planFlags.parsePlanConfig()
+ if err != nil {
+ return err
+ }
+
+ if *lockTimeout < 0 {
+ return errors.New("lock timeout must be >= 0")
+ }
+
+ cmd.SilenceUsage = true
+
+ plan, err := generatePlan(context.Background(), log.SimpleLogger(), connConfig, planConfig)
+ if err != nil {
+ return err
+ } else if len(plan.Statements) == 0 {
+ fmt.Println("Schema matches expected. No plan generated")
+ return nil
+ }
+
+ fmt.Println(header("Review plan"))
+ fmt.Print(planToPrettyS(plan), "\n\n")
+
+ if err := failIfHazardsNotAllowed(plan, *allowedHazardsTypesStrs); err != nil {
+ return err
+ }
+ if err := mustContinuePrompt(
+ fmt.Sprintf(
+ "Apply migration with the following hazards: %s?",
+ strings.Join(*allowedHazardsTypesStrs, ", "),
+ ),
+ ); err != nil {
+ return err
+ }
+
+ if err := runPlan(context.Background(), connConfig, plan, lockTimeout); err != nil {
+ return err
+ }
+ fmt.Println("Schema applied successfully")
+ return nil
+ }
+
+ return cmd
+}
+
+func failIfHazardsNotAllowed(plan diff.Plan, allowedHazardsTypesStrs []string) error {
+ isAllowedByHazardType := make(map[diff.MigrationHazardType]bool)
+ for _, val := range allowedHazardsTypesStrs {
+ isAllowedByHazardType[strings.ToUpper(val)] = true
+ }
+ var disallowedHazardMsgs []string
+ for i, stmt := range plan.Statements {
+ var disallowedTypes []diff.MigrationHazardType
+ for _, hzd := range stmt.Hazards {
+ if !isAllowedByHazardType[hzd.Type] {
+ disallowedTypes = append(disallowedTypes, hzd.Type)
+ }
+ }
+ if len(disallowedTypes) > 0 {
+ disallowedHazardMsgs = append(disallowedHazardMsgs,
+ fmt.Sprintf("- Statement %d: %s", getDisplayableStmtIdx(i), strings.Join(disallowedTypes, ", ")),
+ )
+ }
+
+ }
+ if len(disallowedHazardMsgs) > 0 {
+ return fmt.Errorf(
+ "Prohibited hazards found\n"+
+ "These hazards must be allowed via the allow-hazards flag, e.g., --allow-hazards %s\n"+
+ "Prohibited hazards in the following statements:\n%s",
+ strings.Join(getHazardTypes(plan), ","),
+ strings.Join(disallowedHazardMsgs, "\n"),
+ )
+ }
+ return nil
+}
+
+func runPlan(ctx context.Context, connConfig *pgx.ConnConfig, plan diff.Plan, lockTimeout *time.Duration) error {
+ connPool, err := openDbWithPgxConfig(connConfig)
+ if err != nil {
+ return err
+ }
+ defer connPool.Close()
+
+ conn, err := connPool.Conn(ctx)
+ if err != nil {
+ return err
+ }
+ defer conn.Close()
+
+ _, err = conn.ExecContext(ctx, fmt.Sprintf("SET SESSION lock_timeout = %d", lockTimeout.Milliseconds()))
+ if err != nil {
+ return fmt.Errorf("setting lock timeout: %w", err)
+ }
+
+ // Due to the way *sql.Db works, when a statement_timeout is set for the session, it will NOT reset
+ // by default when it's returned to the pool.
+ //
+ // We can't set the timeout at the TRANSACTION-level (for each transaction) because `ADD INDEX CONCURRENTLY`
+ // must be executed within its own transaction block. Postgres will error if you try to set a TRANSACTION-level
+ // timeout for it. SESSION-level statement_timeouts are respected by `ADD INDEX CONCURRENTLY`
+ for i, stmt := range plan.Statements {
+ fmt.Println(header(fmt.Sprintf("Executing statement %d", getDisplayableStmtIdx(i))))
+ fmt.Printf("%s\n\n", statementToPrettyS(stmt))
+ start := time.Now()
+ if _, err := conn.ExecContext(ctx, fmt.Sprintf("SET SESSION statement_timeout = %d", stmt.Timeout.Milliseconds())); err != nil {
+ return fmt.Errorf("setting statement timeout: %w", err)
+ }
+ if _, err := conn.ExecContext(ctx, stmt.ToSQL()); err != nil {
+ // could the migration statement contain sensitive information?
+ return fmt.Errorf("executing migration statement. the database maybe be in a dirty state: %s: %w", stmt, err)
+ }
+ fmt.Printf("Finished executing statement. Duration: %s\n", time.Since(start))
+ }
+ fmt.Println(header("Complete"))
+
+ return nil
+}
+
+func getHazardTypes(plan diff.Plan) []diff.MigrationHazardType {
+ seenHazardTypes := make(map[diff.MigrationHazardType]bool)
+ var hazardTypes []diff.MigrationHazardType
+ for _, stmt := range plan.Statements {
+ for _, hazard := range stmt.Hazards {
+ if !seenHazardTypes[hazard.Type] {
+ seenHazardTypes[hazard.Type] = true
+ hazardTypes = append(hazardTypes, hazard.Type)
+ }
+ }
+ }
+ sort.Slice(hazardTypes, func(i, j int) bool {
+ return hazardTypes[i] < hazardTypes[j]
+ })
+ return hazardTypes
+}
+
+func getDisplayableStmtIdx(i int) int {
+ return i + 1
+}
diff --git a/cmd/pg-schema-diff/cli.go b/cmd/pg-schema-diff/cli.go
new file mode 100644
index 0000000..edd447d
--- /dev/null
+++ b/cmd/pg-schema-diff/cli.go
@@ -0,0 +1,46 @@
+package main
+
+import (
+ "fmt"
+ "math"
+ "strings"
+
+ "github.com/manifoldco/promptui"
+)
+
+func header(header string) string {
+ const headerTargetWidth = 80
+
+ if len(header) > headerTargetWidth {
+ return header
+ }
+
+ if len(header) > 0 {
+ header = fmt.Sprintf(" %s ", header)
+ }
+ hashTagsOnSide := int(math.Ceil(float64(headerTargetWidth-len(header)) / 2))
+
+ rightHashTags := strings.Repeat("#", hashTagsOnSide)
+ leftHashTags := rightHashTags
+ if headerTargetWidth-len(header)-2*hashTagsOnSide > 0 {
+ leftHashTags += "#"
+ }
+ return fmt.Sprintf("%s%s%s", leftHashTags, header, rightHashTags)
+}
+
+// mustContinuePrompt prompts the user if they want to continue, and returns an error otherwise.
+// promptui requires the continue label to be one line
+func mustContinuePrompt(continueLabel string) error {
+ if len(continueLabel) == 0 {
+ continueLabel = "Continue?"
+ }
+ if _, result, err := (&promptui.Select{
+ Label: continueLabel,
+ Items: []string{"No", "Yes"},
+ }).Run(); err != nil {
+ return err
+ } else if result == "No" {
+ return fmt.Errorf("user aborted")
+ }
+ return nil
+}
diff --git a/cmd/pg-schema-diff/flags.go b/cmd/pg-schema-diff/flags.go
new file mode 100644
index 0000000..9dfe478
--- /dev/null
+++ b/cmd/pg-schema-diff/flags.go
@@ -0,0 +1,42 @@
+package main
+
+import (
+ "os"
+
+ "github.com/jackc/pgx/v4"
+ "github.com/spf13/cobra"
+)
+
+type connFlags struct {
+ dsn *string
+}
+
+func createConnFlags(cmd *cobra.Command) connFlags {
+ dsn := cmd.Flags().String("dsn", "", "Connection string for the database (DB password can be specified through PGPASSWORD environment variable)")
+ mustMarkFlagAsRequired(cmd, "dsn")
+
+ return connFlags{
+ dsn: dsn,
+ }
+}
+
+func (c connFlags) parseConnConfig() (*pgx.ConnConfig, error) {
+ config, err := pgx.ParseConfig(*c.dsn)
+ if err != nil {
+ return nil, err
+ }
+
+ if config.Password == "" {
+ if pgPassword := os.Getenv("PGPASSWORD"); pgPassword != "" {
+ config.Password = pgPassword
+ }
+ }
+
+ return config, nil
+}
+
+func mustMarkFlagAsRequired(cmd *cobra.Command, flagName string) {
+ if err := cmd.MarkFlagRequired(flagName); err != nil {
+ panic(err)
+ }
+}
diff --git a/cmd/pg-schema-diff/main.go b/cmd/pg-schema-diff/main.go
new file mode 100644
index 0000000..1e92303
--- /dev/null
+++ b/cmd/pg-schema-diff/main.go
@@ -0,0 +1,25 @@
+package main
+
+import (
+ "os"
+
+ "github.com/spf13/cobra"
+)
+
+// rootCmd represents the base command when called without any subcommands
+var rootCmd = &cobra.Command{
+ Use: "pg-schema-diff",
+ Short: "Diff two Postgres schemas and generate the SQL to get from one to the other",
+}
+
+func init() {
+ rootCmd.AddCommand(buildPlanCmd())
+ rootCmd.AddCommand(buildApplyCmd())
+}
+
+func main() {
+ err := rootCmd.Execute()
+ if err != nil {
+ os.Exit(1)
+ }
+}
diff --git a/cmd/pg-schema-diff/plan_cmd.go b/cmd/pg-schema-diff/plan_cmd.go
new file mode 100644
index 0000000..5e3511b
--- /dev/null
+++ b/cmd/pg-schema-diff/plan_cmd.go
@@ -0,0 +1,331 @@
+package main
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "os"
+ "path/filepath"
+ "regexp"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/jackc/pgx/v4"
+ "github.com/spf13/cobra"
+ "github.com/stripe/pg-schema-diff/pkg/diff"
+ "github.com/stripe/pg-schema-diff/pkg/log"
+ "github.com/stripe/pg-schema-diff/pkg/tempdb"
+)
+
+var (
+ // Match arguments in the format "regex=duration" where duration is any duration valid in time.ParseDuration
+ // We'll let time.ParseDuration handle the complexity of parsing invalid duration, so the regex we're extracting is
+ // all characters greedily up to the rightmost "="
+ statementTimeoutModifierRegex = regexp.MustCompile(`^(?P<regex>.+)=(?P<duration>.+)$`)
+ regexSTMRegexIndex = statementTimeoutModifierRegex.SubexpIndex("regex")
+ durationSTMRegexIndex = statementTimeoutModifierRegex.SubexpIndex("duration")
+
+ // Match arguments in the format "index duration:statement" where duration is any duration valid in
+ // time.ParseDuration. In order to prevent matching on ":" in the duration, limit the characters to letters,
+ // numbers, and periods. To keep the regex simple, we won't bother matching on a more specific pattern for durations.
+ // time.ParseDuration can handle the complexity of parsing invalid durations
+ insertStatementRegex = regexp.MustCompile(`^(?P<index>\d+) (?P<duration>[a-zA-Z0-9\.]+):(?P<ddl>.+?);?$`)
+ indexInsertStatementRegexIndex = insertStatementRegex.SubexpIndex("index")
+ durationInsertStatementRegexIndex = insertStatementRegex.SubexpIndex("duration")
+ ddlInsertStatementRegexIndex = insertStatementRegex.SubexpIndex("ddl")
+)
+
+func buildPlanCmd() *cobra.Command {
+ cmd := &cobra.Command{
+ Use: "plan",
+ Aliases: []string{"diff"},
+ Short: "Generate the diff between two databases and the SQL to get from one to the other",
+ }
+
+ connFlags := createConnFlags(cmd)
+ planFlags := createPlanFlags(cmd)
+ cmd.RunE = func(cmd *cobra.Command, args []string) error {
+ connConfig, err := connFlags.parseConnConfig()
+ if err != nil {
+ return err
+ }
+
+ planConfig, err := planFlags.parsePlanConfig()
+ if err != nil {
+ return err
+ }
+
+ cmd.SilenceUsage = true
+
+ plan, err := generatePlan(context.Background(), log.SimpleLogger(), connConfig, planConfig)
+ if err != nil {
+ return err
+ } else if len(plan.Statements) == 0 {
+ fmt.Println("Schema matches expected. No plan generated")
+ return nil
+ }
+ fmt.Printf("\n%s\n", header("Generated plan"))
+ fmt.Println(planToPrettyS(plan))
+ return nil
+ }
+
+ return cmd
+}
+
+type (
+ planFlags struct {
+ schemaDir *string
+ statementTimeoutModifiers *[]string
+ insertStatements *[]string
+ }
+
+ statementTimeoutModifier struct {
+ regex *regexp.Regexp
+ timeout time.Duration
+ }
+
+ insertStatement struct {
+ ddl string
+ index int
+ timeout time.Duration
+ }
+
+ planConfig struct {
+ schemaDir string
+ statementTimeoutModifiers []statementTimeoutModifier
+ insertStatements []insertStatement
+ }
+)
+
+func createPlanFlags(cmd *cobra.Command) planFlags {
+ schemaDir := cmd.Flags().String("schema-dir", "", "Directory containing schema files")
+ mustMarkFlagAsRequired(cmd, "schema-dir")
+
+ statementTimeoutModifiers := cmd.Flags().StringArrayP("statement-timeout-modifier", "t", nil,
+ "regex=timeout key-value pairs, where if a statement matches the regex, the statement will have the target"+
+ " timeout. If multiple regexes match, the latest regex will take priority. Example: -t 'CREATE TABLE=5m' -t 'CONCURRENTLY=10s'")
+ insertStatements := cmd.Flags().StringArrayP("insert-statement", "s", nil,
+ "_: values. Will insert the statement at the index in the "+
+ "generated plan with the specified timeout. This follows normal insert semantics. Example: -s '0 5s:SELECT 1''")
+
+ return planFlags{
+ schemaDir: schemaDir,
+ statementTimeoutModifiers: statementTimeoutModifiers,
+ insertStatements: insertStatements,
+ }
+}
+
+func (p planFlags) parsePlanConfig() (planConfig, error) {
+ var statementTimeoutModifiers []statementTimeoutModifier
+ for _, s := range *p.statementTimeoutModifiers {
+ stm, err := parseStatementTimeoutModifierStr(s)
+ if err != nil {
+ return planConfig{}, fmt.Errorf("parsing statement timeout modifier from %q: %w", s, err)
+ }
+ statementTimeoutModifiers = append(statementTimeoutModifiers, stm)
+ }
+
+ var insertStatements []insertStatement
+ for _, i := range *p.insertStatements {
+ is, err := parseInsertStatementStr(i)
+ if err != nil {
+ return planConfig{}, fmt.Errorf("parsing insert statement from %q: %w", i, err)
+ }
+ insertStatements = append(insertStatements, is)
+ }
+
+ return planConfig{
+ schemaDir: *p.schemaDir,
+ statementTimeoutModifiers: statementTimeoutModifiers,
+ insertStatements: insertStatements,
+ }, nil
+}
+
+func parseStatementTimeoutModifierStr(val string) (statementTimeoutModifier, error) {
+ submatches := statementTimeoutModifierRegex.FindStringSubmatch(val)
+ if len(submatches) <= regexSTMRegexIndex || len(submatches) <= durationSTMRegexIndex {
+ return statementTimeoutModifier{}, fmt.Errorf("could not parse regex and duration from arg. expected to be in the format of " +
+ "'Some.*Regex='. Example durations include: 2s, 5m, 10.5h")
+ }
+ regexStr := submatches[regexSTMRegexIndex]
+ durationStr := submatches[durationSTMRegexIndex]
+
+ regex, err := regexp.Compile(regexStr)
+ if err != nil {
+ return statementTimeoutModifier{}, fmt.Errorf("regex could not be compiled from %q: %w", regexStr, err)
+ }
+
+ duration, err := time.ParseDuration(durationStr)
+ if err != nil {
+ return statementTimeoutModifier{}, fmt.Errorf("duration could not be parsed from %q: %w", durationStr, err)
+ }
+
+ return statementTimeoutModifier{
+ regex: regex,
+ timeout: duration,
+ }, nil
+}
+
+func parseInsertStatementStr(val string) (insertStatement, error) {
+ submatches := insertStatementRegex.FindStringSubmatch(val)
+ if len(submatches) <= indexInsertStatementRegexIndex ||
+ len(submatches) <= durationInsertStatementRegexIndex ||
+ len(submatches) <= ddlInsertStatementRegexIndex {
+ return insertStatement{}, fmt.Errorf("could not parse index, duration, and statement from arg. expected to be in the " +
+ "format of ' :'. Example durations include: 2s, 5m, 10.5h")
+ }
+ indexStr := submatches[indexInsertStatementRegexIndex]
+ index, err := strconv.Atoi(indexStr)
+ if err != nil {
+ return insertStatement{}, fmt.Errorf("could not parse index (an int) from \"%q\"", indexStr)
+ }
+
+ durationStr := submatches[durationInsertStatementRegexIndex]
+ duration, err := time.ParseDuration(durationStr)
+ if err != nil {
+ return insertStatement{}, fmt.Errorf("duration could not be parsed from \"%q\": %w", durationStr, err)
+ }
+
+ return insertStatement{
+ index: index,
+ ddl: submatches[ddlInsertStatementRegexIndex],
+ timeout: duration,
+ }, nil
+}
+
+func generatePlan(ctx context.Context, logger log.Logger, connConfig *pgx.ConnConfig, planConfig planConfig) (diff.Plan, error) {
+ ddl, err := getDDLFromPath(planConfig.schemaDir)
+ if err != nil {
+ return diff.Plan{}, err
+ }
+
+ tempDbFactory, err := tempdb.NewOnInstanceFactory(ctx, func(ctx context.Context, dbName string) (*sql.DB, error) {
+ copiedConfig := connConfig.Copy()
+ copiedConfig.Database = dbName
+ return openDbWithPgxConfig(copiedConfig)
+ })
+ if err != nil {
+ return diff.Plan{}, err
+ }
+ defer func() {
+ err := tempDbFactory.Close()
+ if err != nil {
+ logger.Errorf("error shutting down temp db factory: %v", err)
+ }
+ }()
+
+ connPool, err := openDbWithPgxConfig(connConfig)
+ if err != nil {
+ return diff.Plan{}, err
+ }
+ defer connPool.Close()
+
+ conn, err := connPool.Conn(ctx)
+ if err != nil {
+ return diff.Plan{}, err
+ }
+ defer conn.Close()
+
+ plan, err := diff.GeneratePlan(ctx, conn, tempDbFactory, ddl,
+ diff.WithDataPackNewTables(),
+ diff.WithIgnoreChangesToColOrder(),
+ )
+ if err != nil {
+ return diff.Plan{}, fmt.Errorf("generating plan: %w", err)
+ }
+
+ modifiedPlan, err := applyPlanModifiers(plan, planConfig.statementTimeoutModifiers, planConfig.insertStatements)
+ if err != nil {
+ return diff.Plan{}, fmt.Errorf("applying plan modifiers: %w", err)
+ }
+
+ return modifiedPlan, nil
+}
+
+func applyPlanModifiers(
+ plan diff.Plan,
+ statementTimeoutModifiers []statementTimeoutModifier,
+ insertStatements []insertStatement,
+) (diff.Plan, error) {
+ for _, stm := range statementTimeoutModifiers {
+ plan = plan.ApplyStatementTimeoutModifier(stm.regex, stm.timeout)
+ }
+ for _, is := range insertStatements {
+ var err error
+ plan, err = plan.InsertStatement(is.index, diff.Statement{
+ DDL: is.ddl,
+ Timeout: is.timeout,
+ Hazards: []diff.MigrationHazard{{
+ Type: diff.MigrationHazardTypeIsUserGenerated,
+ Message: "This statement is user-generated",
+ }},
+ })
+ if err != nil {
+ return diff.Plan{}, fmt.Errorf("inserting statement %q with timeout %s at index %d: %w",
+ is.ddl, is.timeout, is.index, err)
+ }
+ }
+ return plan, nil
+}
+
+func getDDLFromPath(path string) ([]string, error) {
+ fileEntries, err := os.ReadDir(path)
+ if err != nil {
+ return nil, err
+ }
+ var ddl []string
+ for _, entry := range fileEntries {
+ if filepath.Ext(entry.Name()) == ".sql" {
+ if stmts, err := os.ReadFile(filepath.Join(path, entry.Name())); err != nil {
+ return nil, err
+ } else {
+ ddl = append(ddl, string(stmts))
+ }
+ }
+ }
+ return ddl, nil
+}
+
+func planToPrettyS(plan diff.Plan) string {
+ sb := strings.Builder{}
+
+ // We are going to put a statement index before each statement. To do that,
+ // we need to find how many characters are in the largest index, so we can provide the appropriate amount
+ // of padding before the statements to align all of them
+ // E.g.
+ // 1. ALTER TABLE foobar ADD COLUMN foo BIGINT
+ // ....
+ // 22. ADD INDEX some_idx ON some_other_table(some_column)
+ stmtNumPadding := len(strconv.Itoa(len(plan.Statements))) // find how much padding is required for the statement index
+ fmtString := fmt.Sprintf("%%0%dd. %%s", stmtNumPadding) // supply custom padding
+
+ var stmtStrs []string
+ for i, stmt := range plan.Statements {
+ stmtStr := fmt.Sprintf(fmtString, getDisplayableStmtIdx(i), statementToPrettyS(stmt))
+ stmtStrs = append(stmtStrs, stmtStr)
+ }
+ sb.WriteString(strings.Join(stmtStrs, "\n\n"))
+
+ return sb.String()
+}
+
+func statementToPrettyS(stmt diff.Statement) string {
+ sb := strings.Builder{}
+ sb.WriteString(fmt.Sprintf("%s;", stmt.DDL))
+ sb.WriteString(fmt.Sprintf("\n\t-- Timeout: %s", stmt.Timeout))
+ if len(stmt.Hazards) > 0 {
+ for _, hazard := range stmt.Hazards {
+ sb.WriteString(fmt.Sprintf("\n\t-- Hazard %s", hazardToPrettyS(hazard)))
+ }
+ }
+ return sb.String()
+}
+
+func hazardToPrettyS(hazard diff.MigrationHazard) string {
+ if len(hazard.Message) > 0 {
+ return fmt.Sprintf("%s: %s", hazard.Type, hazard.Message)
+ } else {
+ return hazard.Type
+ }
+}
diff --git a/cmd/pg-schema-diff/plan_cmd_test.go b/cmd/pg-schema-diff/plan_cmd_test.go
new file mode 100644
index 0000000..3908bfe
--- /dev/null
+++ b/cmd/pg-schema-diff/plan_cmd_test.go
@@ -0,0 +1,119 @@
+package main
+
+import (
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestParseStatementTimeoutModifierStr(t *testing.T) {
+ for _, tc := range []struct {
+ opt string `explicit:"always"`
+
+ expectedRegexStr string
+ expectedTimeout time.Duration
+ expectedErrContains string
+ }{
+ {
+ opt: "normal duration=5m",
+ expectedRegexStr: "normal duration",
+ expectedTimeout: 5 * time.Minute,
+ },
+ {
+ opt: "some regex with a duration ending in a period=5.h",
+ expectedRegexStr: "some regex with a duration ending in a period",
+ expectedTimeout: 5 * time.Hour,
+ },
+ {
+ opt: " starts with spaces than has a *=5.5m",
+ expectedRegexStr: " starts with spaces than has a *",
+ expectedTimeout: time.Minute*5 + 30*time.Second,
+ },
+ {
+ opt: "has a valid opt in the regex something=5.5m in the regex =15s",
+ expectedRegexStr: "has a valid opt in the regex something=5.5m in the regex ",
+ expectedTimeout: 15 * time.Second,
+ },
+ {
+ opt: "has multiple valid opts opt=15m5s in the regex something=5.5m in the regex and has compound duration=15m1ms2us10ns",
+ expectedRegexStr: "has multiple valid opts opt=15m5s in the regex something=5.5m in the regex and has compound duration",
+ expectedTimeout: 15*time.Minute + 1*time.Millisecond + 2*time.Microsecond + 10*time.Nanosecond,
+ },
+ {
+ opt: "=5m",
+ expectedErrContains: "could not parse regex and duration from arg",
+ },
+ {
+ opt: "15m",
+ expectedErrContains: "could not parse regex and duration from arg",
+ },
+ {
+ opt: "someregex;15m",
+ expectedErrContains: "could not parse regex and duration from arg",
+ },
+ {
+ opt: "someregex=invalid duration5s",
+ expectedErrContains: "duration could not be parsed",
+ },
+ } {
+ t.Run(tc.opt, func(t *testing.T) {
+ modifier, err := parseStatementTimeoutModifierStr(tc.opt)
+ if len(tc.expectedErrContains) > 0 {
+ assert.ErrorContains(t, err, tc.expectedErrContains)
+ return
+ }
+ require.NoError(t, err)
+ assert.Equal(t, tc.expectedRegexStr, modifier.regex.String())
+ assert.Equal(t, tc.expectedTimeout, modifier.timeout)
+ })
+ }
+}
+
+func TestParseInsertStatementStr(t *testing.T) {
+ for _, tc := range []struct {
+ opt string `explicit:"always"`
+ expectedInsertStmt insertStatement
+ expectedErrContains string
+ }{
+ {
+ opt: "1 0h5.1m:SELECT * FROM :TABLE:0_5m:something",
+ expectedInsertStmt: insertStatement{
+ index: 1,
+ ddl: "SELECT * FROM :TABLE:0_5m:something",
+ timeout: 5*time.Minute + 6*time.Second,
+ },
+ },
+ {
+ opt: "0 100ms:SELECT 1; SELECT * FROM something;",
+ expectedInsertStmt: insertStatement{
+ index: 0,
+ ddl: "SELECT 1; SELECT * FROM something",
+ timeout: 100 * time.Millisecond,
+ },
+ },
+ {
+ opt: " 5s:No index",
+ expectedErrContains: "could not parse",
+ },
+ {
+ opt: "0 5g:Invalid duration",
+ expectedErrContains: "duration could not be parsed",
+ },
+ {
+ opt: "0 5s:",
+ expectedErrContains: "could not parse",
+ },
+ } {
+ t.Run(tc.opt, func(t *testing.T) {
+ insertStatement, err := parseInsertStatementStr(tc.opt)
+ if len(tc.expectedErrContains) > 0 {
+ assert.ErrorContains(t, err, tc.expectedErrContains)
+ return
+ }
+ require.NoError(t, err)
+ assert.Equal(t, tc.expectedInsertStmt, insertStatement)
+ })
+ }
+}
diff --git a/cmd/pg-schema-diff/sql.go b/cmd/pg-schema-diff/sql.go
new file mode 100644
index 0000000..7d13533
--- /dev/null
+++ b/cmd/pg-schema-diff/sql.go
@@ -0,0 +1,18 @@
+package main
+
+import (
+ "database/sql"
+
+ "github.com/jackc/pgx/v4"
+ "github.com/jackc/pgx/v4/stdlib"
+)
+
+// openDbWithPgxConfig opens a database connection using the provided pgx.ConnConfig and pings it
+func openDbWithPgxConfig(config *pgx.ConnConfig) (*sql.DB, error) {
+ connPool := stdlib.OpenDB(*config)
+ if err := connPool.Ping(); err != nil {
+ connPool.Close()
+ return nil, err
+ }
+ return connPool, nil
+}
diff --git a/go.mod b/go.mod
new file mode 100644
index 0000000..373915d
--- /dev/null
+++ b/go.mod
@@ -0,0 +1,35 @@
+module github.com/stripe/pg-schema-diff
+
+go 1.18
+
+require (
+ github.com/google/go-cmp v0.5.9
+ github.com/google/uuid v1.3.0
+ github.com/jackc/pgx/v4 v4.14.0
+ github.com/kr/pretty v0.3.1
+ github.com/mitchellh/hashstructure/v2 v2.0.2
+ github.com/stretchr/testify v1.8.2
+)
+
+require (
+ github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e // indirect
+ github.com/davecgh/go-spew v1.1.1 // indirect
+ github.com/inconshreveable/mousetrap v1.1.0 // indirect
+ github.com/jackc/chunkreader/v2 v2.0.1 // indirect
+ github.com/jackc/pgconn v1.14.0 // indirect
+ github.com/jackc/pgio v1.0.0 // indirect
+ github.com/jackc/pgpassfile v1.0.0 // indirect
+ github.com/jackc/pgproto3/v2 v2.3.2 // indirect
+ github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
+ github.com/jackc/pgtype v1.14.0 // indirect
+ github.com/kr/text v0.2.0 // indirect
+ github.com/manifoldco/promptui v0.9.0 // indirect
+ github.com/pmezard/go-difflib v1.0.0 // indirect
+ github.com/rogpeppe/go-internal v1.9.0 // indirect
+ github.com/spf13/cobra v1.7.0 // indirect
+ github.com/spf13/pflag v1.0.5 // indirect
+ golang.org/x/crypto v0.6.0 // indirect
+ golang.org/x/sys v0.5.0 // indirect
+ golang.org/x/text v0.7.0 // indirect
+ gopkg.in/yaml.v3 v3.0.1 // indirect
+)
diff --git a/go.sum b/go.sum
new file mode 100644
index 0000000..635fea5
--- /dev/null
+++ b/go.sum
@@ -0,0 +1,234 @@
+github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
+github.com/Masterminds/semver/v3 v3.1.1 h1:hLg3sBzpNErnxhQtUy/mmLR2I9foDujNK030IGemrRc=
+github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs=
+github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
+github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e h1:fY5BOSpyZCqRo5OhCuC+XN+r/bBCmeuuJtjz+bCNIf8=
+github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=
+github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
+github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I=
+github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ=
+github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
+github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
+github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
+github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY=
+github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
+github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
+github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
+github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A=
+github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
+github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw=
+github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM=
+github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
+github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
+github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
+github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
+github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
+github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
+github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
+github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo=
+github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk=
+github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8=
+github.com/jackc/chunkreader/v2 v2.0.1/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk=
+github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA=
+github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE=
+github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s=
+github.com/jackc/pgconn v1.8.0/go.mod h1:1C2Pb36bGIP9QHGBYCjnyhqu7Rv3sGshaQUvmfGIB/o=
+github.com/jackc/pgconn v1.9.0/go.mod h1:YctiPyvzfU11JFxoXokUOOKQXQmDMoJL9vJzHH8/2JY=
+github.com/jackc/pgconn v1.9.1-0.20210724152538-d89c8390a530/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI=
+github.com/jackc/pgconn v1.10.1/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI=
+github.com/jackc/pgconn v1.14.0 h1:vrbA9Ud87g6JdFWkHTJXppVce58qPIdP7N8y0Ml/A7Q=
+github.com/jackc/pgconn v1.14.0/go.mod h1:9mBNlny0UvkgJdCDvdVHYSjI+8tD2rnKK69Wz8ti++E=
+github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE=
+github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8=
+github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE=
+github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd/go.mod h1:hrBW0Enj2AZTNpt/7Y5rr2xe/9Mn757Wtb2xeBzPv2c=
+github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 h1:DadwsjnMwFjfWc9y5Wi/+Zz7xoE5ALHsRQlOctkOiHc=
+github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65/go.mod h1:5R2h2EEX+qri8jOWMbJCtaPWkrrNc7OHwsp2TCqp7ak=
+github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
+github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
+github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78=
+github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA=
+github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg=
+github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM=
+github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM=
+github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA=
+github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA=
+github.com/jackc/pgproto3/v2 v2.2.0/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA=
+github.com/jackc/pgproto3/v2 v2.3.2 h1:7eY55bdBeCz1F2fTzSz69QC+pG46jYq9/jtSPiJ5nn0=
+github.com/jackc/pgproto3/v2 v2.3.2/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA=
+github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E=
+github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
+github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
+github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg=
+github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc=
+github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw=
+github.com/jackc/pgtype v1.8.1-0.20210724151600-32e20a603178/go.mod h1:C516IlIV9NKqfsMCXTdChteoXmwgUceqaLfjg2e3NlM=
+github.com/jackc/pgtype v1.9.0/go.mod h1:LUMuVrfsFfdKGLw+AFFVv6KtHOFMwRgDDzBt76IqCA4=
+github.com/jackc/pgtype v1.14.0 h1:y+xUdabmyMkJLyApYuPj38mW+aAIqCe5uuBB51rH3Vw=
+github.com/jackc/pgtype v1.14.0/go.mod h1:LUMuVrfsFfdKGLw+AFFVv6KtHOFMwRgDDzBt76IqCA4=
+github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y=
+github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM=
+github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc=
+github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c/go.mod h1:1QD0+tgSXP7iUjYm9C1NxKhny7lq6ee99u/z+IHFcgs=
+github.com/jackc/pgx/v4 v4.14.0 h1:TgdrmgnM7VY72EuSQzBbBd4JA1RLqJolrw9nQVZABVc=
+github.com/jackc/pgx/v4 v4.14.0/go.mod h1:jT3ibf/A0ZVCp89rtCIN0zCJxcE74ypROmHEZYsG/j8=
+github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
+github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
+github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
+github.com/jackc/puddle v1.2.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
+github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
+github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
+github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
+github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
+github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
+github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
+github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
+github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw=
+github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
+github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
+github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
+github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
+github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
+github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
+github.com/lib/pq v1.10.2 h1:AqzbZs4ZoCBp+GtejcpCpcxM3zlSMx29dXbUSeVtJb8=
+github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
+github.com/manifoldco/promptui v0.9.0 h1:3V4HzJk1TtXW1MTZMP7mdlwbBpIinw3HztaIlYthEiA=
+github.com/manifoldco/promptui v0.9.0/go.mod h1:ka04sppxSGFAtxX0qhlYQjISsg9mR4GWtQEhdbn6Pgg=
+github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ=
+github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc=
+github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
+github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
+github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
+github.com/mitchellh/hashstructure/v2 v2.0.2 h1:vGKWl0YJqUNxE8d+h8f6NJLcCJrgbhC4NcD46KavDd4=
+github.com/mitchellh/hashstructure/v2 v2.0.2/go.mod h1:MG3aRVU/N29oo/V/IhBX8GR/zz4kQkprJgF2EVszyDE=
+github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
+github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
+github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
+github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
+github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
+github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
+github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
+github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ=
+github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU=
+github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc=
+github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
+github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0=
+github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4=
+github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ=
+github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o=
+github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q=
+github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
+github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
+github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
+github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
+github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
+github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
+github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
+github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
+github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
+github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
+github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
+github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
+github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
+github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
+github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
+github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
+github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
+github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
+github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
+github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
+github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
+github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q=
+go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
+go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
+go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ=
+go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ=
+go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0=
+go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4=
+go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU=
+go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA=
+go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
+go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
+go.uber.org/zap v1.13.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM=
+golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
+golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE=
+golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
+golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
+golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
+golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
+golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
+golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
+golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
+golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
+golang.org/x/crypto v0.6.0 h1:qfktjS5LUO+fFKeJXZ+ikTRijMmljikvG68fpMMruSc=
+golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58=
+golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
+golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc=
+golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
+golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
+golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
+golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
+golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
+golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
+golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
+golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
+golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
+golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
+golang.org/x/sys v0.0.0-20181122145206-62eef0e2fa9b/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
+golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
+golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
+golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU=
+golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
+golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
+golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
+golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
+golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
+golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
+golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
+golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
+golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
+golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
+golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo=
+golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
+golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
+golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
+golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
+golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
+golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
+golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
+golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
+golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
+golang.org/x/tools v0.0.0-20200103221440-774c71fcf114/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
+golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
+golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
+gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
+gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s=
+gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
+gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
+gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg=
diff --git a/internal/graph/graph.go b/internal/graph/graph.go
new file mode 100644
index 0000000..4d26620
--- /dev/null
+++ b/internal/graph/graph.go
@@ -0,0 +1,215 @@
+package graph
+
+import (
+ "fmt"
+ "sort"
+)
+
+type Vertex interface {
+ GetId() string
+}
+
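+// AdjacencyMatrix maps a source vertex id to the set of target vertex ids it has edges to.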
+type AdjacencyMatrix map[string]map[string]bool
+
+// Graph is a directed graph
+type Graph[V Vertex] struct {
+ verticesById map[string]V
+ edges AdjacencyMatrix
+}
+
+func NewGraph[V Vertex]() *Graph[V] {
+ return &Graph[V]{
+ verticesById: make(map[string]V),
+ edges: make(AdjacencyMatrix),
+ }
+}
+
+// AddVertex adds a vertex to the graph.
+// If the vertex already exists, it is overridden and its existing edges are kept
+func (g *Graph[V]) AddVertex(v V) {
+ g.verticesById[v.GetId()] = v
+ if g.edges[v.GetId()] == nil {
+ g.edges[v.GetId()] = make(map[string]bool)
+ }
+}
+
+// AddEdge adds an edge to the graph. If either vertex does not exist, it returns an error
+func (g *Graph[V]) AddEdge(sourceId, targetId string) error {
+ if !g.HasVertexWithId(sourceId) {
+ return fmt.Errorf("source %s does not exist", sourceId)
+ }
+ if !g.HasVertexWithId(targetId) {
+ return fmt.Errorf("target %s does not exist", targetId)
+ }
+ g.edges[sourceId][targetId] = true
+
+ return nil
+}
+
+// Union merges another graph into this graph. If a vertex exists in both graphs,
+// the merge function determines the resulting vertex
+func (g *Graph[V]) Union(new *Graph[V], merge func(old, new V) V) error {
+ for _, newV := range new.verticesById {
+ if g.HasVertexWithId(newV.GetId()) {
+ id := newV.GetId()
+ // merge the vertices using the procedure defined by the user
+ newV = merge(g.GetVertex(newV.GetId()), newV)
+ if newV.GetId() != id {
+ return fmt.Errorf("the merge function must return a vertex with the same id: "+
+ "expected %s but found %s", id, newV.GetId())
+ }
+ }
+ g.AddVertex(newV)
+ }
+
+ for source, adjacentEdgesMap := range new.edges {
+ for target, isAdjacent := range adjacentEdgesMap {
+ if isAdjacent {
+ if err := g.AddEdge(source, target); err != nil {
+ return fmt.Errorf("adding an edge from the new graph: %w", err)
+ }
+ }
+ }
+ }
+
+ return nil
+}
+
+func (g *Graph[V]) GetVertex(id string) V {
+ return g.verticesById[id]
+}
+
+func (g *Graph[V]) HasVertexWithId(id string) bool {
+ _, hasVertex := g.verticesById[id]
+ return hasVertex
+}
+
+// Reverse reverses the edges of the graph: sources become sinks and vice versa.
+func (g *Graph[V]) Reverse() {
+ reversedEdges := make(AdjacencyMatrix)
+ for vertexId := range g.verticesById {
+ reversedEdges[vertexId] = make(map[string]bool)
+ }
+ for source, adjacentEdgesMap := range g.edges {
+ for target, isAdjacent := range adjacentEdgesMap {
+ if isAdjacent {
+ reversedEdges[target][source] = true
+ }
+ }
+ }
+ g.edges = reversedEdges
+}
+
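+// Copy returns a copy of the graph. The vertex map and adjacency matrix are copied;
+// vertex values themselves are copied by assignment.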
+func (g *Graph[V]) Copy() *Graph[V] {
+ verticesById := make(map[string]V)
+ for id, v := range g.verticesById {
+ verticesById[id] = v
+ }
+
+ edges := make(AdjacencyMatrix)
+ for source, adjacentEdgesMap := range g.edges {
+ edges[source] = make(map[string]bool)
+ for target, isAdjacent := range adjacentEdgesMap {
+ edges[source][target] = isAdjacent
+ }
+ }
+
+ return &Graph[V]{
+ verticesById: verticesById,
+ edges: edges,
+ }
+}
+
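+// TopologicallySort returns a deterministic topological sort of the graph, breaking ties by ascending vertex id.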
+func (g *Graph[V]) TopologicallySort() ([]V, error) {
+ return g.TopologicallySortWithPriority(func(_, _ V) bool {
+ return false
+ })
+}
+
+type Ordered interface {
+ ~int | ~string
+}
+
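+// IsLowerPriorityFromGetPriority builds the comparison function expected by TopologicallySortWithPriority
+// from a getPriority function: v1 is lower priority than v2 when getPriority(v1) < getPriority(v2).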
+func IsLowerPriorityFromGetPriority[V Vertex, P Ordered](getPriority func(V) P) func(V, V) bool {
+ return func(v1 V, v2 V) bool {
+ return getPriority(v1) < getPriority(v2)
+ }
+}
+
+// TopologicallySortWithPriority returns a consistent topological sort of the graph, taking a greedy approach to put
+// high-priority sources first. The output is deterministic. isLowerPriority must be deterministic.
+func (g *Graph[V]) TopologicallySortWithPriority(isLowerPriority func(V, V) bool) ([]V, error) {
+ // This algorithm mutates the graph, so work on a copy
+ graph := g.Copy()
+
+ // The strategy of this algorithm:
+ // 1. Count the number of incoming edges to each vertex
+ // 2. Remove the sources (vertices with no incoming edges) and append them to the output
+ // 3. Decrement the incoming edge count of each vertex adjacent to the removed source
+ // 4. Repeat 2-3 until the graph is empty
+
+ // The number of outgoing edges in the reversed graph equals the number of incoming edges in the original graph.
+ // In other words, a source in the graph (no incoming edges) is a sink in the reversed graph
+ // (no outgoing edges)
+
+ // To find the number of incoming edges of a vertex in the graph, just count its outgoing edges
+ // in the reversed graph
+ reversedGraph := graph.Copy()
+ reversedGraph.Reverse()
+ incomingEdgeCountByVertex := make(map[string]int)
+ for vertex, reversedAdjacentEdges := range reversedGraph.edges {
+ count := 0
+ for _, isAdjacent := range reversedAdjacentEdges {
+ if isAdjacent {
+ count++
+ }
+ }
+ incomingEdgeCountByVertex[vertex] = count
+ }
+
+ var output []V
+ // Add the sources to the output. Delete the sources. Repeat.
+ for len(graph.verticesById) > 0 {
+ // Put all the sources into an array, so we can get a stable sort of them before identifying the one
+ // with the highest priority
+ var sources []V
+ for sourceId, incomingEdgeCount := range incomingEdgeCountByVertex {
+ if incomingEdgeCount == 0 {
+ sources = append(sources, g.GetVertex(sourceId))
+ }
+ }
+ sort.Slice(sources, func(i, j int) bool {
+ return sources[i].GetId() < sources[j].GetId()
+ })
+
+ // Take the source with highest priority from the sorted array of sources
+ indexOfSourceWithHighestPri := -1
+ for i, source := range sources {
+ if indexOfSourceWithHighestPri == -1 || isLowerPriority(sources[indexOfSourceWithHighestPri], source) {
+ indexOfSourceWithHighestPri = i
+ }
+ }
+ if indexOfSourceWithHighestPri == -1 {
+ return nil, fmt.Errorf("cycle detected: %+v, %+v", graph, incomingEdgeCountByVertex)
+ }
+ sourceWithHighestPriority := sources[indexOfSourceWithHighestPri]
+
+ output = append(output, sourceWithHighestPriority)
+
+ // Remove the source vertex from the graph and
+ // decrement the incoming edge counts of its targets
+ for target, hasEdge := range graph.edges[sourceWithHighestPriority.GetId()] {
+ if hasEdge {
+ incomingEdgeCountByVertex[target]--
+ }
+ }
+
+ // Delete the vertex from graph
+ delete(graph.verticesById, sourceWithHighestPriority.GetId())
+ // We don't need to worry about any incoming edges referencing this vertex, since it was a source (no incoming edges)
+ delete(graph.edges, sourceWithHighestPriority.GetId())
+ delete(incomingEdgeCountByVertex, sourceWithHighestPriority.GetId())
+ }
+
+ return output, nil
+}
diff --git a/internal/graph/graph_test.go b/internal/graph/graph_test.go
new file mode 100644
index 0000000..de0a010
--- /dev/null
+++ b/internal/graph/graph_test.go
@@ -0,0 +1,432 @@
+package graph
+
+import (
+ "fmt"
+ "strconv"
+ "testing"
+
+ "github.com/google/uuid"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestAddGetHasOperations(t *testing.T) {
+ g := NewGraph[vertex]()
+
+ // getting a missing vertex returns the zero value
+ assert.Zero(t, g.GetVertex("missing_vertex"))
+
+ // add vertex
+ v1 := NewV("v_1")
+ g.AddVertex(v1)
+ assert.True(t, g.HasVertexWithId("v_1"))
+ assert.Equal(t, v1, g.GetVertex("v_1"))
+
+ // override vertex
+ newV1 := NewV("v_1")
+ g.AddVertex(newV1)
+ assert.True(t, g.HasVertexWithId("v_1"))
+ assert.Equal(t, newV1, g.GetVertex("v_1"))
+ assert.NotEqual(t, v1, g.GetVertex("v_1"))
+
+ // add edge when target is missing
+ assert.Error(t, g.AddEdge("v_1", "missing_vertex"))
+
+ // add edge when source is missing
+ assert.Error(t, g.AddEdge("missing_vertex", "v_1"))
+
+ // add edge when both are present
+ v2 := NewV("v_2")
+ g.AddVertex(v2)
+ assert.True(t, g.HasVertexWithId("v_2"))
+ assert.NoError(t, g.AddEdge("v_1", "v_2"))
+ assert.Equal(t, AdjacencyMatrix{
+ "v_1": {
+ "v_2": true,
+ },
+ "v_2": {},
+ }, g.edges)
+
+ // overriding vertex keeps edges
+ g.AddVertex(NewV("v_1"))
+ g.AddVertex(NewV("v_2"))
+ assert.Equal(t, AdjacencyMatrix{
+ "v_1": {
+ "v_2": true,
+ },
+ "v_2": {},
+ }, g.edges)
+
+ // override edge
+ assert.NoError(t, g.AddEdge("v_1", "v_2"))
+ assert.Equal(t, AdjacencyMatrix{
+ "v_1": {
+ "v_2": true,
+ },
+ "v_2": {},
+ }, g.edges)
+
+ // allow cycles
+ assert.NoError(t, g.AddEdge("v_2", "v_1"))
+ assert.Equal(t, AdjacencyMatrix{
+ "v_1": {
+ "v_2": true,
+ },
+ "v_2": {
+ "v_1": true,
+ },
+ }, g.edges)
+}
+
+func TestReverse(t *testing.T) {
+ g := NewGraph[vertex]()
+ g.AddVertex(NewV("v_1"))
+ g.AddVertex(NewV("v_2"))
+ g.AddVertex(NewV("v_3"))
+ assert.NoError(t, g.AddEdge("v_1", "v_2"))
+ assert.NoError(t, g.AddEdge("v_1", "v_3"))
+ assert.NoError(t, g.AddEdge("v_2", "v_3"))
+ g.Reverse()
+ assert.Equal(t, AdjacencyMatrix{
+ "v_1": {},
+ "v_2": {
+ "v_1": true,
+ },
+ "v_3": {
+ "v_1": true,
+ "v_2": true,
+ },
+ }, g.edges)
+ assert.ElementsMatch(t, getVertexIds(g), []string{"v_1", "v_2", "v_3"})
+}
+
+func TestCopy(t *testing.T) {
+ g := NewGraph[vertex]()
+ g.AddVertex(NewV("shared_1"))
+ g.AddVertex(NewV("shared_2"))
+ g.AddVertex(NewV("shared_3"))
+ assert.NoError(t, g.AddEdge("shared_1", "shared_3"))
+ assert.NoError(t, g.AddEdge("shared_3", "shared_2"))
+
+ gC := g.Copy()
+ gC.AddVertex(NewV("g_copy_1"))
+ gC.AddVertex(NewV("g_copy_2"))
+ assert.NoError(t, gC.AddEdge("g_copy_1", "g_copy_2"))
+ assert.NoError(t, gC.AddEdge("shared_3", "g_copy_1"))
+ assert.NoError(t, gC.AddEdge("shared_3", "shared_1"))
+ copyOverrideShared1 := NewV("shared_1")
+ gC.AddVertex(copyOverrideShared1)
+
+ g.AddVertex(NewV("g_1"))
+ g.AddVertex(NewV("g_2"))
+ g.AddVertex(NewV("g_3"))
+ assert.NoError(t, g.AddEdge("g_3", "g_1"))
+ assert.NoError(t, g.AddEdge("g_2", "shared_2"))
+ assert.NoError(t, g.AddEdge("shared_2", "shared_3"))
+ originalOverrideShared2 := NewV("shared_2")
+ g.AddVertex(originalOverrideShared2)
+
+ // validate nodes on copy
+ assert.ElementsMatch(t, getVertexIds(gC), []string{
+ "shared_1", "shared_2", "shared_3", "g_copy_1", "g_copy_2",
+ })
+
+ // validate nodes on original
+ assert.ElementsMatch(t, getVertexIds(g), []string{
+ "shared_1", "shared_2", "shared_3", "g_1", "g_2", "g_3",
+ })
+
+ // validate overrides weren't copied over and non-overridden shared nodes are the same
+ assert.NotEqual(t, g.GetVertex("shared_1"), gC.GetVertex("shared_1"))
+ assert.Equal(t, gC.GetVertex("shared_1"), copyOverrideShared1)
+ assert.NotEqual(t, g.GetVertex("shared_2"), gC.GetVertex("shared_2"))
+ assert.Equal(t, g.GetVertex("shared_2"), originalOverrideShared2)
+ assert.Equal(t, g.GetVertex("shared_3"), gC.GetVertex("shared_3"))
+
+ // validate edges
+ assert.Equal(t, AdjacencyMatrix{
+ "shared_1": {
+ "shared_3": true,
+ },
+ "shared_2": {},
+ "shared_3": {
+ "shared_1": true,
+ "shared_2": true,
+ "g_copy_1": true,
+ },
+ "g_copy_1": {
+ "g_copy_2": true,
+ },
+ "g_copy_2": {},
+ }, gC.edges)
+ assert.Equal(t, AdjacencyMatrix{
+ "shared_1": {
+ "shared_3": true,
+ },
+ "shared_2": {
+ "shared_3": true,
+ },
+ "shared_3": {
+ "shared_2": true,
+ },
+ "g_1": {},
+ "g_2": {
+ "shared_2": true,
+ },
+ "g_3": {
+ "g_1": true,
+ },
+ }, g.edges)
+}
+
+func TestUnion(t *testing.T) {
+ gA := NewGraph[vertex]()
+ gA1 := NewV("a_1")
+ gA.AddVertex(gA1)
+ gA2 := NewV("a_2")
+ gA.AddVertex(gA2)
+ gA3 := NewV("a_3")
+ gA.AddVertex(gA3)
+ gAShared1 := NewV("shared_1")
+ gA.AddVertex(gAShared1)
+ gAShared2 := NewV("shared_2")
+ gA.AddVertex(gAShared2)
+ assert.NoError(t, gA.AddEdge("a_1", "a_2"))
+ assert.NoError(t, gA.AddEdge("a_3", "a_1"))
+ assert.NoError(t, gA.AddEdge("shared_1", "a_1"))
+
+ gB := NewGraph[vertex]()
+ gB1 := NewV("b_1")
+ gB.AddVertex(gB1)
+ gB2 := NewV("b_2")
+ gB.AddVertex(gB2)
+ gB3 := NewV("b_3")
+ gB.AddVertex(gB3)
+ gBShared1 := NewV("shared_1")
+ gB.AddVertex(gBShared1)
+ gBShared2 := NewV("shared_2")
+ gB.AddVertex(gBShared2)
+ assert.NoError(t, gB.AddEdge("b_3", "b_2"))
+ assert.NoError(t, gB.AddEdge("b_3", "b_1"))
+ assert.NoError(t, gB.AddEdge("shared_1", "b_2"))
+
+ err := gA.Union(gB, func(old, new vertex) vertex {
+ return vertex{
+ id: old.id,
+ val: fmt.Sprintf("%s_%s", old.val, new.val),
+ }
+ })
+ assert.NoError(t, err)
+
+ // make sure non-shared nodes were not merged
+ assert.Equal(t, gA1, gA.GetVertex("a_1"))
+ assert.Equal(t, gA2, gA.GetVertex("a_2"))
+ assert.Equal(t, gA3, gA.GetVertex("a_3"))
+ assert.Equal(t, gB1, gA.GetVertex("b_1"))
+ assert.Equal(t, gB2, gA.GetVertex("b_2"))
+ assert.Equal(t, gB3, gA.GetVertex("b_3"))
+
+ // check merged nodes
+ assert.Equal(t, vertex{
+ id: "shared_1",
+ val: fmt.Sprintf("%s_%s", gAShared1.val, gBShared1.val),
+ }, gA.GetVertex("shared_1"))
+ assert.Equal(t, vertex{
+ id: "shared_2",
+ val: fmt.Sprintf("%s_%s", gAShared2.val, gBShared2.val),
+ }, gA.GetVertex("shared_2"))
+
+ // no extra nodes
+ assert.ElementsMatch(t, getVertexIds(gA), []string{
+ "a_1", "a_2", "a_3", "b_1", "b_2", "b_3", "shared_1", "shared_2",
+ })
+
+ assert.Equal(t, AdjacencyMatrix{
+ "a_1": {
+ "a_2": true,
+ },
+ "a_2": {},
+ "a_3": {
+ "a_1": true,
+ },
+ "b_1": {},
+ "b_2": {},
+ "b_3": {
+ "b_1": true,
+ "b_2": true,
+ },
+ "shared_1": {
+ "a_1": true,
+ "b_2": true,
+ },
+ "shared_2": {},
+ }, gA.edges)
+}
+
+func TestUnionFailsIfMergeReturnsDifferentId(t *testing.T) {
+ gA := NewGraph[vertex]()
+ gA.AddVertex(NewV("shared_1"))
+ gB := NewGraph[vertex]()
+ gB.AddVertex(NewV("shared_1"))
+ assert.Error(t, gA.Union(gB, func(old, new vertex) vertex {
+ return vertex{
+ id: "different_id",
+ val: old.val,
+ }
+ }))
+}
+
+func TestTopologicallySort(t *testing.T) {
+ // Source: https://en.wikipedia.org/wiki/Topological_sorting#Examples
+ g := NewGraph[vertex]()
+ v5 := NewV("05")
+ g.AddVertex(v5)
+ v7 := NewV("07")
+ g.AddVertex(v7)
+ v3 := NewV("03")
+ g.AddVertex(v3)
+ v11 := NewV("11")
+ g.AddVertex(v11)
+ v8 := NewV("08")
+ g.AddVertex(v8)
+ v2 := NewV("02")
+ g.AddVertex(v2)
+ v9 := NewV("09")
+ g.AddVertex(v9)
+ v10 := NewV("10")
+ g.AddVertex(v10)
+ assert.NoError(t, g.AddEdge("05", "11"))
+ assert.NoError(t, g.AddEdge("07", "11"))
+ assert.NoError(t, g.AddEdge("07", "08"))
+ assert.NoError(t, g.AddEdge("03", "08"))
+ assert.NoError(t, g.AddEdge("03", "10"))
+ assert.NoError(t, g.AddEdge("11", "02"))
+ assert.NoError(t, g.AddEdge("11", "09"))
+ assert.NoError(t, g.AddEdge("11", "10"))
+ assert.NoError(t, g.AddEdge("08", "09"))
+
+ orderedNodes, err := g.TopologicallySort()
+ assert.NoError(t, err)
+ assert.Equal(t, []vertex{
+ v3, v5, v7, v8, v11, v2, v9, v10,
+ }, orderedNodes)
+
+ // Cycle should error
+ assert.NoError(t, g.AddEdge("10", "07"))
+ _, err = g.TopologicallySort()
+ assert.Error(t, err)
+}
+
+func TestTopologicallySortWithPriority(t *testing.T) {
+ // Source: https://en.wikipedia.org/wiki/Topological_sorting#Examples
+ g := NewGraph[vertex]()
+ v5 := NewV("05")
+ g.AddVertex(v5)
+ v7 := NewV("07")
+ g.AddVertex(v7)
+ v3 := NewV("03")
+ g.AddVertex(v3)
+ v11 := NewV("11")
+ g.AddVertex(v11)
+ v8 := NewV("08")
+ g.AddVertex(v8)
+ v2 := NewV("02")
+ g.AddVertex(v2)
+ v9 := NewV("09")
+ g.AddVertex(v9)
+ v10 := NewV("10")
+ g.AddVertex(v10)
+ assert.NoError(t, g.AddEdge("05", "11"))
+ assert.NoError(t, g.AddEdge("07", "11"))
+ assert.NoError(t, g.AddEdge("07", "08"))
+ assert.NoError(t, g.AddEdge("03", "08"))
+ assert.NoError(t, g.AddEdge("03", "10"))
+ assert.NoError(t, g.AddEdge("11", "02"))
+ assert.NoError(t, g.AddEdge("11", "09"))
+ assert.NoError(t, g.AddEdge("11", "10"))
+ assert.NoError(t, g.AddEdge("08", "09"))
+
+ for _, tc := range []struct {
+ name string
+ isLowerPriority func(v1, v2 vertex) bool
+ expectedOrdering []vertex
+ }{
+ {
+ name: "largest-numbered available vertex first (string-based GetPriority)",
+ isLowerPriority: IsLowerPriorityFromGetPriority(func(v vertex) string {
+ return v.GetId()
+ }),
+ expectedOrdering: []vertex{v7, v5, v11, v3, v10, v8, v9, v2},
+ },
+ {
+ name: "smallest-numbered available vertex first (numeric-based GetPriority)",
+ isLowerPriority: IsLowerPriorityFromGetPriority(func(v vertex) int {
+ idAsInt, err := strconv.Atoi(v.GetId())
+ require.NoError(t, err)
+ return -idAsInt
+ }),
+ expectedOrdering: []vertex{v3, v5, v7, v8, v11, v2, v9, v10},
+ },
+ {
+ name: "fewest edges first (prioritize high id's for tie breakers)",
+ isLowerPriority: func(v1, v2 vertex) bool {
+ v1EdgeCount := getEdgeCount(g, v1)
+ v2EdgeCount := getEdgeCount(g, v2)
+ if v1EdgeCount == v2EdgeCount {
+ // Break ties with ID
+ return v1.GetId() < v2.GetId()
+ }
+ return v1EdgeCount > v2EdgeCount
+ },
+ expectedOrdering: []vertex{v5, v7, v3, v8, v11, v10, v9, v2},
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+
+ // Sort with the given priority function and verify the deterministic expected ordering
+ orderedNodes, err := g.TopologicallySortWithPriority(tc.isLowerPriority)
+ assert.NoError(t, err)
+ assert.Equal(t, tc.expectedOrdering, orderedNodes)
+ })
+ }
+
+ // Cycle should error
+ assert.NoError(t, g.AddEdge("10", "07"))
+ _, err := g.TopologicallySort()
+ assert.Error(t, err)
+}
+
+func getEdgeCount[V Vertex](g *Graph[V], v Vertex) int {
+ edgeCount := 0
+ for _, hasEdge := range g.edges[v.GetId()] {
+ if hasEdge {
+ edgeCount++
+ }
+ }
+ return edgeCount
+}
+
+type vertex struct {
+ id string
+ val string
+}
+
+func NewV(id string) vertex {
+ uuid, err := uuid.NewUUID()
+ if err != nil {
+ panic(err)
+ }
+ return vertex{id: id, val: uuid.String()}
+}
+
+func (v vertex) GetId() string {
+ return v.id
+}
+
+func getVertexIds(g *Graph[vertex]) []string {
+ var output []string
+ for id := range g.verticesById {
+ output = append(output, id)
+ }
+ return output
+}
diff --git a/internal/migration_acceptance_tests/acceptance_test.go b/internal/migration_acceptance_tests/acceptance_test.go
new file mode 100644
index 0000000..940dab7
--- /dev/null
+++ b/internal/migration_acceptance_tests/acceptance_test.go
@@ -0,0 +1,206 @@
+package migration_acceptance_tests
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "testing"
+
+ _ "github.com/jackc/pgx/v4/stdlib"
+ "github.com/kr/pretty"
+ "github.com/stretchr/testify/suite"
+ "github.com/stripe/pg-schema-diff/internal/pgdump"
+ "github.com/stripe/pg-schema-diff/internal/pgengine"
+ "github.com/stripe/pg-schema-diff/pkg/diff"
+
+ "github.com/stripe/pg-schema-diff/pkg/log"
+ "github.com/stripe/pg-schema-diff/pkg/tempdb"
+)
+
+type (
+ expectations struct {
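+ // planErrorIs and planErrorContains assert on the error returned by plan generation, when set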
+ planErrorIs error
+ planErrorContains string
+ // outputState should be the DDL required to reconstruct the expected output state of the database
+ //
+ // The outputState might differ from the newSchemaDDL due to options passed to the migrator. For example,
+ // the data packing option will cause the column ordering for new tables in the outputState to differ from
+ // the column ordering of those tables defined in newSchemaDDL
+ //
+ // If no outputState is specified, the newSchemaDDL will be used
+ outputState []string
+ }
+
+ acceptanceTestCase struct {
+ name string
+ oldSchemaDDL []string
+ newSchemaDDL []string
+
+ // expectedHazardTypes should contain all the unique migration hazard types that are expected to be within the
+ // generated plan
+ expectedHazardTypes []diff.MigrationHazardType
+
+ // vanillaExpectations refers to the expectations of the migration if no additional opts are used
+ vanillaExpectations expectations
+ // dataPackingExpectations refers to the expectations of the migration if table packing and ignore column order are used
+ dataPackingExpectations expectations
+ }
+
+ acceptanceTestSuite struct {
+ suite.Suite
+ pgEngine *pgengine.Engine
+ }
+)
+
+func (suite *acceptanceTestSuite) SetupSuite() {
+ engine, err := pgengine.StartEngine()
+ suite.Require().NoError(err)
+ suite.pgEngine = engine
+}
+
+func (suite *acceptanceTestSuite) TearDownSuite() {
+ suite.pgEngine.Close()
+}
+
+// Simulates migrating a database and uses pgdump to compare the actual state to the expected state
+func (suite *acceptanceTestSuite) runTestCases(acceptanceTestCases []acceptanceTestCase) {
+ for _, tc := range acceptanceTestCases {
+ suite.Run(tc.name, func() {
+ suite.Run("vanilla", func() {
+ suite.runSubtest(tc, tc.vanillaExpectations, nil)
+ })
+ suite.Run("with data packing (and ignoring column order)", func() {
+ suite.runSubtest(tc, tc.dataPackingExpectations, []diff.PlanOpt{
+ diff.WithDataPackNewTables(),
+ diff.WithIgnoreChangesToColOrder(),
+ diff.WithLogger(log.SimpleLogger()),
+ })
+ })
+ })
+ }
+}
+
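+// runSubtest migrates a fresh database from the old schema to the new schema using the generated plan,
+// then compares a schema-only pg_dump of the migrated database against a dump of a database built directly
+// from the expected DDL. Finally, it asserts that regenerating a plan produces no further statements.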
+func (suite *acceptanceTestSuite) runSubtest(tc acceptanceTestCase, expects expectations, planOpts []diff.PlanOpt) {
+ // onDbInitQueries will be run on both the old database before the migration and the new database before pg_dump
+ onDbInitQueries := []string{
+ // Enable an extension to enforce that diffing works with extensions enabled
+ `CREATE EXTENSION amcheck;`,
+ }
+
+ // normalize the subtest expectations
+ if expects.outputState == nil {
+ expects.outputState = tc.newSchemaDDL
+ }
+
+ // Apply old schema DDL to old DB
+ oldDb, err := suite.pgEngine.CreateDatabase()
+ suite.Require().NoError(err)
+ defer oldDb.DropDB()
+ // Apply the old schema
+ suite.Require().NoError(applyDDL(oldDb, append(onDbInitQueries, tc.oldSchemaDDL...)))
+
+ // Migrate the old DB
+ oldDBConnPool, err := sql.Open("pgx", oldDb.GetDSN())
+ suite.Require().NoError(err)
+ defer oldDBConnPool.Close()
+ oldDbConn, err := oldDBConnPool.Conn(context.Background())
+ suite.Require().NoError(err)
+ defer oldDbConn.Close()
+
+ tempDbFactory, err := tempdb.NewOnInstanceFactory(context.Background(), func(ctx context.Context, dbName string) (*sql.DB, error) {
+ return sql.Open("pgx", suite.pgEngine.GetPostgresDatabaseConnOpts().With("dbname", dbName).ToDSN())
+ })
+ suite.Require().NoError(err)
+ defer func(tempDbFactory tempdb.Factory) {
+ // It's important that this closes properly (the temp database is dropped),
+ // so assert it has no error for acceptance tests
+ suite.Require().NoError(tempDbFactory.Close())
+ }(tempDbFactory)
+
+ plan, err := diff.GeneratePlan(context.Background(), oldDbConn, tempDbFactory, tc.newSchemaDDL, planOpts...)
+
+ if expects.planErrorIs != nil || len(expects.planErrorContains) > 0 {
+ if expects.planErrorIs != nil {
+ suite.ErrorIs(err, expects.planErrorIs)
+ }
+ if len(expects.planErrorContains) > 0 {
+ suite.ErrorContains(err, expects.planErrorContains)
+ }
+ return
+ }
+ suite.Require().NoError(err)
+
+ suite.ElementsMatch(tc.expectedHazardTypes, getUniqueHazardTypesFromStatements(plan.Statements), prettySprintPlan(plan))
+
+ // Apply the plan
+ suite.Require().NoError(applyPlan(oldDb, plan), prettySprintPlan(plan))
+
+ // Make sure the pgdump after running the migration is the same as the
+ // pgdump from a database where we directly run the newSchemaDDL
+ oldDbDump, err := pgdump.GetDump(oldDb, pgdump.WithSchemaOnly())
+ suite.Require().NoError(err)
+
+ newDbDump := suite.directlyRunDDLAndGetDump(append(onDbInitQueries, expects.outputState...))
+ suite.Equal(newDbDump, oldDbDump, prettySprintPlan(plan))
+
+ // Make sure no diff is found if we try to regenerate a plan
+ plan, err = diff.GeneratePlan(context.Background(), oldDbConn, tempDbFactory, tc.newSchemaDDL, planOpts...)
+ suite.Require().NoError(err)
+ suite.Empty(plan.Statements, prettySprintPlan(plan))
+}
+
+func (suite *acceptanceTestSuite) directlyRunDDLAndGetDump(ddl []string) string {
+ newDb, err := suite.pgEngine.CreateDatabase()
+ suite.Require().NoError(err)
+ defer newDb.DropDB()
+ suite.Require().NoError(applyDDL(newDb, ddl))
+
+ newDbDump, err := pgdump.GetDump(newDb, pgdump.WithSchemaOnly())
+ suite.Require().NoError(err)
+ return newDbDump
+}
+
+func applyDDL(db *pgengine.DB, ddl []string) error {
+ conn, err := sql.Open("pgx", db.GetDSN())
+ if err != nil {
+ return err
+ }
+ defer conn.Close()
+
+ for _, stmt := range ddl {
+ _, err := conn.Exec(stmt)
+ if err != nil {
+ return fmt.Errorf("DDL:\n%s: %w", stmt, err)
+ }
+ }
+ return nil
+}
+
+func applyPlan(db *pgengine.DB, plan diff.Plan) error {
+ var ddl []string
+ for _, stmt := range plan.Statements {
+ ddl = append(ddl, stmt.ToSQL())
+ }
+ return applyDDL(db, ddl)
+}
+
+func getUniqueHazardTypesFromStatements(statements []diff.Statement) []diff.MigrationHazardType {
+ var seenHazardTypes = make(map[diff.MigrationHazardType]bool)
+ var hazardTypes []diff.MigrationHazardType
+ for _, stmt := range statements {
+ for _, hazard := range stmt.Hazards {
+ if _, hasHazard := seenHazardTypes[hazard.Type]; !hasHazard {
+ seenHazardTypes[hazard.Type] = true
+ hazardTypes = append(hazardTypes, hazard.Type)
+ }
+ }
+ }
+ return hazardTypes
+}
+
+func prettySprintPlan(plan diff.Plan) string {
+ return fmt.Sprintf("%# v", pretty.Formatter(plan.Statements))
+}
+
+func TestAcceptanceSuite(t *testing.T) {
+ suite.Run(t, new(acceptanceTestSuite))
+}
diff --git a/internal/migration_acceptance_tests/check_constraint_cases_test.go b/internal/migration_acceptance_tests/check_constraint_cases_test.go
new file mode 100644
index 0000000..599e2ba
--- /dev/null
+++ b/internal/migration_acceptance_tests/check_constraint_cases_test.go
@@ -0,0 +1,512 @@
+package migration_acceptance_tests
+
+import (
+ "github.com/stripe/pg-schema-diff/pkg/diff"
+)
+
+var checkConstraintCases = []acceptanceTestCase{
+ {
+ name: "No-op",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255),
+ bar BIGINT CHECK ( bar > id )
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255),
+ bar BIGINT CHECK ( bar > id )
+ );
+ `,
+ },
+ },
+ {
+ name: "Add check constraint",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255),
+ bar BIGINT
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255),
+ bar BIGINT CHECK ( bar > id )
+ );
+ `,
+ },
+ },
+ {
+ name: "Add check constraint with UDF dependency should error",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255),
+ bar INT
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE FUNCTION add(a integer, b integer) RETURNS integer
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN a + b;
+
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255),
+ bar INT CHECK ( add(bar, id) > 0 )
+ );
+ `,
+ },
+ vanillaExpectations: expectations{
+ planErrorIs: diff.ErrNotImplemented,
+ },
+ dataPackingExpectations: expectations{
+ planErrorIs: diff.ErrNotImplemented,
+ },
+ },
+ {
+ name: "Add check constraint with system function dependency should not error",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255),
+ bar INT
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255),
+ bar INT CHECK ( to_timestamp(id) <= CURRENT_TIMESTAMP )
+ );
+ `,
+ },
+ },
+ {
+ name: "Add multiple check constraints",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255),
+ bar BIGINT
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255),
+ bar BIGINT,
+ CHECK ( bar > id ), CHECK ( bar IS NOT NULL ), CHECK (bar > 0)
+ );
+ `,
+ },
+ },
+ {
+ name: "Add check constraints to new column",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255)
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255),
+ bar BIGINT CHECK ( bar > id ), CHECK ( bar IS NOT NULL ), CHECK (bar > 0)
+ );
+ `,
+ },
+ },
+ {
+ name: "Add check constraint with quoted identifiers",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ "ID" INT PRIMARY KEY,
+ foo VARCHAR(255),
+ "Bar" BIGINT
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ "ID" INT PRIMARY KEY,
+ foo VARCHAR(255),
+ "Bar" BIGINT
+ );
+ ALTER TABLE foobar ADD CONSTRAINT "BAR_CHECK" CHECK ( "Bar" < "ID" );
+ `,
+ },
+ },
+ {
+ name: "Add no inherit check constraint",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255),
+ bar BIGINT
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255),
+ bar BIGINT
+ );
+ ALTER TABLE foobar ADD CONSTRAINT bar_check CHECK ( bar > id ) NO INHERIT;
+ `,
+ },
+ },
+ {
+ name: "Add No-Inherit, Not-Valid check constraint",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255),
+ bar BIGINT
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255),
+ bar BIGINT
+ );
+ ALTER TABLE foobar ADD CONSTRAINT bar_check CHECK ( bar > id ) NO INHERIT NOT VALID;
+ `,
+ },
+ },
+ {
+ name: "Drop check constraint",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255),
+ bar BIGINT CHECK ( bar > id )
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255),
+ bar BIGINT
+ );
+ `,
+ },
+ },
+ {
+ name: "Drop check constraint with quoted identifiers",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ "ID" INT PRIMARY KEY,
+ foo VARCHAR(255),
+ "Bar" BIGINT
+ );
+ ALTER TABLE foobar ADD CONSTRAINT "BAR_CHECK" CHECK ( "Bar" < "ID" );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ "ID" INT PRIMARY KEY,
+ foo VARCHAR(255),
+ "Bar" BIGINT
+ );
+ `,
+ },
+ },
+ {
+ name: "Drop column with check constraints",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255),
+ bar BIGINT,
+ CHECK ( bar > id ), CHECK ( bar IS NOT NULL ), CHECK (bar > 0)
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255)
+ );
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{diff.MigrationHazardTypeDeletesData},
+ },
+ {
+ name: "Drop check constraint with UDF dependency should error",
+ oldSchemaDDL: []string{
+ `
+ CREATE FUNCTION add(a integer, b integer) RETURNS integer
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN a + b;
+
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255),
+ bar INT CHECK ( add(bar, id) > 0 )
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255),
+ bar INT
+ );
+ `,
+ },
+ vanillaExpectations: expectations{
+ planErrorIs: diff.ErrNotImplemented,
+ },
+ dataPackingExpectations: expectations{
+ planErrorIs: diff.ErrNotImplemented,
+ },
+ },
+ {
+ name: "Drop check constraint with system function dependency should not error",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255),
+ bar INT CHECK ( to_timestamp(id) <= CURRENT_TIMESTAMP )
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255),
+ bar INT
+ );
+ `,
+ },
+ },
+ {
+ name: "Alter an invalid check constraint to be valid",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255),
+ bar BIGINT
+ );
+ ALTER TABLE foobar ADD CONSTRAINT bar_check CHECK ( bar > id ) NOT VALID;
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255),
+ bar BIGINT
+ );
+ ALTER TABLE foobar ADD CONSTRAINT bar_check CHECK ( bar > id );
+ `,
+ },
+ },
+ {
+ name: "Alter a valid check constraint to be invalid",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255),
+ bar BIGINT
+ );
+ ALTER TABLE foobar ADD CONSTRAINT bar_check CHECK ( bar > id );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255),
+ bar BIGINT
+ );
+ ALTER TABLE foobar ADD CONSTRAINT bar_check CHECK ( bar > id ) NOT VALID;
+ `,
+ },
+ },
+ {
+ name: "Alter a No-Inherit check constraint to be Inheritable",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255),
+ bar BIGINT
+ );
+ ALTER TABLE foobar ADD CONSTRAINT bar_check CHECK ( bar > id ) NO INHERIT;
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255),
+ bar BIGINT
+ );
+ ALTER TABLE foobar ADD CONSTRAINT bar_check CHECK ( bar > id );
+ `,
+ },
+ },
+ {
+ name: "Alter an Inheritable check constraint to be No-Inherit",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255),
+ bar BIGINT
+ );
+ ALTER TABLE foobar ADD CONSTRAINT bar_check CHECK ( bar > id );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255),
+ bar BIGINT
+ );
+ ALTER TABLE foobar ADD CONSTRAINT bar_check CHECK ( bar > id ) NO INHERIT;
+ `,
+ },
+ },
+ {
+ name: "Alter a check constraint expression",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255),
+ bar BIGINT CHECK (bar > id)
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255),
+ bar BIGINT CHECK (bar < id)
+ );
+ `,
+ },
+ },
+ {
+ name: "Alter check constraint with UDF dependency should error",
+ oldSchemaDDL: []string{
+ `
+ CREATE FUNCTION add(a integer, b integer) RETURNS integer
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN a + b;
+
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255),
+ bar INT
+ );
+ ALTER TABLE foobar ADD CONSTRAINT some_constraint CHECK ( add(bar, id) > 0 ) NOT VALID;
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE FUNCTION add(a integer, b integer) RETURNS integer
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN a + b;
+
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255),
+ bar INT
+ );
+ ALTER TABLE foobar ADD CONSTRAINT some_constraint CHECK ( add(bar, id) > 0 );
+ `,
+ },
+ vanillaExpectations: expectations{
+ planErrorIs: diff.ErrNotImplemented,
+ },
+ dataPackingExpectations: expectations{
+ planErrorIs: diff.ErrNotImplemented,
+ },
+ },
+ {
+ name: "Alter check constraint with system function dependency should not error",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255),
+ bar INT
+ );
+ ALTER TABLE foobar ADD CONSTRAINT some_constraint CHECK ( to_timestamp(id) <= CURRENT_TIMESTAMP ) NOT VALID;
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255),
+ bar INT
+ );
+ ALTER TABLE foobar ADD CONSTRAINT some_constraint CHECK ( to_timestamp(id) <= CURRENT_TIMESTAMP );
+ `,
+ },
+ },
+}
+
+func (suite *acceptanceTestSuite) TestCheckConstraintAcceptanceTestCases() {
+ suite.runTestCases(checkConstraintCases)
+}
diff --git a/internal/migration_acceptance_tests/column_cases_test.go b/internal/migration_acceptance_tests/column_cases_test.go
new file mode 100644
index 0000000..9ef60c9
--- /dev/null
+++ b/internal/migration_acceptance_tests/column_cases_test.go
@@ -0,0 +1,663 @@
+package migration_acceptance_tests
+
+import (
+ "github.com/stripe/pg-schema-diff/pkg/diff"
+)
+
+var columnAcceptanceTestCases = []acceptanceTestCase{
+ {
+ name: "No-op",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255) COLLATE "C" DEFAULT '' NOT NULL,
+ bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ fizz BOOLEAN NOT NULL
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255) COLLATE "C" DEFAULT '' NOT NULL,
+ bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ fizz BOOLEAN NOT NULL
+ );
+ `,
+ },
+ },
+ {
+ name: "Add one column with default",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ my_new_column VARCHAR(255) DEFAULT 'a'
+ );
+ `,
+ },
+ },
+ {
+ name: "Add one column with quoted names",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE "Foobar"(
+ id INT PRIMARY KEY
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE "Foobar"(
+ id INT PRIMARY KEY,
+ "My_new_column" VARCHAR(255) DEFAULT 'a'
+ );
+ `,
+ },
+ },
+ {
+ name: "Add one column with nullability",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ my_new_column VARCHAR(255) NOT NULL
+ );
+ `,
+ },
+ },
+ {
+ name: "Add one column with all options",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ my_new_column VARCHAR(255) COLLATE "C" NOT NULL DEFAULT 'a'
+ );
+ `,
+ },
+ },
+ {
+ name: "Add one column and change ordering",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ my_new_column VARCHAR(255) NOT NULL DEFAULT 'a',
+ id INT PRIMARY KEY
+ );
+ `,
+ },
+ vanillaExpectations: expectations{
+ planErrorIs: diff.ErrColumnOrderingChanged,
+ },
+ dataPackingExpectations: expectations{
+ outputState: []string{`
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ my_new_column VARCHAR(255) NOT NULL DEFAULT 'a'
+ )
+ `},
+ },
+ },
+ {
+ name: "Delete one column",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ );
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeDeletesData,
+ diff.MigrationHazardTypeIndexDropped,
+ },
+ },
+ {
+ name: "Delete one column with quoted name",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE "Foobar"(
+ "Id" INT PRIMARY KEY
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE "Foobar"(
+ );
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeDeletesData,
+ diff.MigrationHazardTypeIndexDropped,
+ },
+ },
+ {
+ name: "Modify data type (varchar -> char)",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foobar VARCHAR(255) NOT NULL
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foobar CHAR NOT NULL
+ );
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeImpactsDatabasePerformance,
+ },
+ },
+ {
+ name: "Modify data type (varchar -> TEXT) with compatible default",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foobar VARCHAR(255) DEFAULT 'some default' NOT NULL
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foobar TEXT DEFAULT 'some default' NOT NULL
+ );
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeImpactsDatabasePerformance,
+ },
+ },
+ {
+ name: "Modify data type and collation (varchar -> char)",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foobar VARCHAR(255) COLLATE "C" NOT NULL
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foobar CHAR COLLATE "POSIX" NOT NULL
+ );
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeImpactsDatabasePerformance,
+ },
+ },
+ {
+ name: "Modify data type to incompatible (bytea -> char)",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foobar bytea NOT NULL
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foobar CHAR NOT NULL
+ );
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeImpactsDatabasePerformance,
+ },
+ },
+ {
+ name: "Modify collation (default -> non-default)",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foobar VARCHAR(255) COLLATE "C" NOT NULL
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foobar VARCHAR(255) COLLATE "POSIX" NOT NULL
+ );
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeImpactsDatabasePerformance,
+ },
+ },
+ {
+ name: "Modify collation (non-default -> non-default)",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foobar VARCHAR(255) NOT NULL
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foobar VARCHAR(255) COLLATE "POSIX" NOT NULL
+ );
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeImpactsDatabasePerformance,
+ },
+ },
+ {
+ name: "Add Default",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foobar VARCHAR(255)
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foobar VARCHAR(255) DEFAULT ''
+ );
+ `,
+ },
+ },
+ {
+ name: "Remove Default",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foobar VARCHAR(255) DEFAULT ''
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foobar VARCHAR(255)
+ );
+ `,
+ },
+ },
+ {
+ name: "Change Default",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foobar VARCHAR(255) DEFAULT 'Something else'
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foobar VARCHAR(255) DEFAULT ''
+ );
+ `,
+ },
+ },
+ {
+ name: "Set NOT NULL",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foobar VARCHAR(255)
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foobar VARCHAR(255) NOT NULL
+ );
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ },
+ },
+ {
+ name: "Remove NOT NULL",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foobar VARCHAR(255) NOT NULL
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foobar VARCHAR(255)
+ );
+ `,
+ },
+ },
+ {
+ name: "Add default and change data type (new default is incompatible with old type)",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foobar INT
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foobar VARCHAR(255) DEFAULT 'SOMETHING'
+ );
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeImpactsDatabasePerformance,
+ },
+ },
+ {
+ name: "Change default and data type (new default is incompatible with old type)",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foobar INT DEFAULT 0
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foobar VARCHAR(255) DEFAULT 'SOMETHING'
+ );
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeImpactsDatabasePerformance,
+ },
+ },
+ {
+ name: "Change default and data type (old default is incompatible with new type)",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foobar TEXT DEFAULT 'SOMETHING'
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foobar INT DEFAULT 8
+ );
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeImpactsDatabasePerformance,
+ },
+ vanillaExpectations: expectations{planErrorContains: "validating migration plan"},
+ dataPackingExpectations: expectations{planErrorContains: "validating migration plan"},
+ },
+ {
+ name: "Change to not null",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foobar INT DEFAULT NULL
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foobar INT NOT NULL
+ );
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ },
+ },
+ {
+ name: "Change from NULL default to no default and NOT NULL",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foobar INT DEFAULT NULL
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foobar INT NOT NULL
+ );
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ },
+ },
+ {
+ name: "Change from NOT NULL to no NULL default",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foobar INT NOT NULL
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foobar INT DEFAULT NULL
+ );
+ `,
+ },
+ },
+ {
+ name: "Change data type and to nullable",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foobar INT NOT NULL
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foobar SMALLINT NULL
+ );
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeImpactsDatabasePerformance,
+ },
+ },
+ {
+ name: "Change data type and to not nullable",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foobar INT
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foobar SMALLINT NOT NULL
+ );
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeImpactsDatabasePerformance,
+ },
+ },
+ {
+ name: "Change data type, nullability (NOT NULL), and default",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foobar TEXT DEFAULT 'some default'
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foobar CHAR NOT NULL DEFAULT 'A'
+ );
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeImpactsDatabasePerformance,
+ },
+ },
+ {
+ name: "Change data type and collation, nullability (NOT NULL), and default with quoted names",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE "Foobar"(
+ id INT PRIMARY KEY,
+ "Foobar" TEXT DEFAULT 'some default'
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE "Foobar"(
+ id INT PRIMARY KEY,
+ "Foobar" CHAR COLLATE "POSIX" NOT NULL DEFAULT 'A'
+ );
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeImpactsDatabasePerformance,
+ },
+ },
+ {
+ name: "Change BIGINT to TIMESTAMP, nullability (NOT NULL), and default with current_timestamp",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE "Foobar"(
+ id INT PRIMARY KEY,
+ some_time_col BIGINT
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE "Foobar"(
+ id INT PRIMARY KEY,
+ some_time_col TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
+ );
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeImpactsDatabasePerformance,
+ },
+ },
+}
+
+func (suite *acceptanceTestSuite) TestColumnAcceptanceTestCases() {
+ suite.runTestCases(columnAcceptanceTestCases)
+}
diff --git a/internal/migration_acceptance_tests/function_cases_test.go b/internal/migration_acceptance_tests/function_cases_test.go
new file mode 100644
index 0000000..86a683c
--- /dev/null
+++ b/internal/migration_acceptance_tests/function_cases_test.go
@@ -0,0 +1,480 @@
+package migration_acceptance_tests
+
+import (
+ "github.com/stripe/pg-schema-diff/pkg/diff"
+)
+
+var functionAcceptanceTestCases = []acceptanceTestCase{
+ {
+ name: "No-op",
+ oldSchemaDDL: []string{
+ `
+ CREATE FUNCTION add(a integer, b integer) RETURNS integer
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN a + b;
+
+ CREATE OR REPLACE FUNCTION increment(i integer) RETURNS integer AS $$
+ BEGIN
+ RETURN i + 1;
+ END;
+ $$ LANGUAGE plpgsql;
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE FUNCTION add(a integer, b integer) RETURNS integer
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN a + b;
+
+ CREATE OR REPLACE FUNCTION increment(i integer) RETURNS integer AS $$
+ BEGIN
+ RETURN i + 1;
+ END;
+ $$ LANGUAGE plpgsql;
+ `,
+ },
+ },
+ {
+ name: "Create functions (with conflicting names)",
+ oldSchemaDDL: nil,
+ newSchemaDDL: []string{
+ `
+ CREATE FUNCTION add(a integer, b integer) RETURNS integer
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN a + b;
+ CREATE FUNCTION add(a text, b text) RETURNS text
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN CONCAT(a, b);
+ `,
+ },
+ },
+ {
+ name: "Create functions with quoted names (with conflicting names)",
+ oldSchemaDDL: nil,
+ newSchemaDDL: []string{
+ `
+ CREATE FUNCTION "some add"(a integer, b integer) RETURNS integer
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN a + b;
+ CREATE FUNCTION "some add"(a text, b text) RETURNS text
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN CONCAT(a, b);
+ `,
+ },
+ },
+ {
+ name: "Create non-sql function",
+ oldSchemaDDL: nil,
+ newSchemaDDL: []string{
+ `
+ CREATE FUNCTION non_sql_func(i integer) RETURNS integer AS $$
+ BEGIN
+ RETURN i + 1;
+ END;
+ $$ LANGUAGE plpgsql;
+ `},
+ expectedHazardTypes: []diff.MigrationHazardType{diff.MigrationHazardTypeHasUntrackableDependencies},
+ },
+ {
+ name: "Create function with dependencies",
+ oldSchemaDDL: nil,
+ newSchemaDDL: []string{
+ `
+ CREATE FUNCTION add(a integer, b integer) RETURNS integer
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN a + b;
+
+ CREATE FUNCTION "increment func"(i integer) RETURNS integer AS $$
+ BEGIN
+ RETURN i + 1;
+ END;
+ $$ LANGUAGE plpgsql;
+
+ CREATE FUNCTION function_with_dependencies(a integer, b integer) RETURNS integer
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN add(a, b) + "increment func"(a);
+
+			-- function with a conflicting name to ensure dependencies specify the parameters
+ CREATE FUNCTION add(a text, b text) RETURNS text
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN CONCAT(a, b);
+ `},
+ expectedHazardTypes: []diff.MigrationHazardType{diff.MigrationHazardTypeHasUntrackableDependencies},
+ },
+ {
+ name: "Drop functions (with conflicting names)",
+ oldSchemaDDL: []string{
+ `
+ CREATE FUNCTION add(a integer, b integer) RETURNS integer
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN a + b;
+ CREATE FUNCTION add(a text, b text) RETURNS text
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN CONCAT(a, b);
+ `,
+ },
+ newSchemaDDL: nil,
+ },
+ {
+ name: "Drop functions with quoted names (with conflicting names)",
+ oldSchemaDDL: []string{
+ `
+ CREATE FUNCTION "some add"(a integer, b integer) RETURNS integer
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN a + b;
+ CREATE FUNCTION "some add"(a text, b text) RETURNS text
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN CONCAT(a, b);
+ `,
+ },
+ newSchemaDDL: nil,
+ },
+ {
+ name: "Drop non-sql function",
+ oldSchemaDDL: []string{
+ `
+ CREATE FUNCTION non_sql_func(i integer) RETURNS integer AS $$
+ BEGIN
+ RETURN i + 1;
+ END;
+ $$ LANGUAGE plpgsql;
+ `},
+ newSchemaDDL: nil,
+ expectedHazardTypes: []diff.MigrationHazardType{diff.MigrationHazardTypeHasUntrackableDependencies},
+ },
+ {
+ name: "Drop function with dependencies",
+ oldSchemaDDL: []string{
+ `
+ CREATE FUNCTION add(a integer, b integer) RETURNS integer
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN a + b;
+
+ CREATE FUNCTION "increment func"(i integer) RETURNS integer AS $$
+ BEGIN
+ RETURN i + 1;
+ END;
+ $$ LANGUAGE plpgsql;
+
+ CREATE FUNCTION function_with_dependencies(a integer, b integer) RETURNS integer
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN add(a, b) + "increment func"(a);
+
+			-- function with a conflicting name to ensure dependencies specify the parameters
+ CREATE FUNCTION add(a text, b text) RETURNS text
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN CONCAT(a, b);
+ `,
+ },
+ newSchemaDDL: nil,
+ expectedHazardTypes: []diff.MigrationHazardType{diff.MigrationHazardTypeHasUntrackableDependencies},
+ },
+ {
+ name: "Alter functions (with conflicting names)",
+ oldSchemaDDL: []string{
+ `
+ CREATE FUNCTION add(a integer, b integer) RETURNS integer
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN a + b;
+ CREATE FUNCTION add(a TEXT, b TEXT) RETURNS TEXT
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN CONCAT(a, b);
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE FUNCTION add(a integer, b integer) RETURNS integer
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN a + a + b;
+ CREATE FUNCTION add(a TEXT, b TEXT) RETURNS TEXT
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN CONCAT(CONCAT(a, a), b);
+ `,
+ },
+ },
+ {
+ name: "Alter functions with quoted names (with conflicting names)",
+ oldSchemaDDL: []string{
+ `
+ CREATE FUNCTION "some add"(a integer, b integer) RETURNS integer
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN a + b;
+ CREATE FUNCTION "some add"(a TEXT, b TEXT) RETURNS TEXT
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN CONCAT(a, b);
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE FUNCTION "some add"(a integer, b integer) RETURNS integer
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN a + a + b;
+ CREATE FUNCTION "some add"(a TEXT, b TEXT) RETURNS TEXT
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN CONCAT(CONCAT(a, a), b);
+ `,
+ },
+ },
+ {
+ name: "Alter non-sql function",
+ oldSchemaDDL: []string{
+ `
+ CREATE FUNCTION non_sql_func(i integer) RETURNS integer AS $$
+ BEGIN
+ RETURN i + 1;
+ END;
+ $$ LANGUAGE plpgsql;
+ `},
+ newSchemaDDL: []string{
+ `
+ CREATE FUNCTION non_sql_func(i integer) RETURNS integer AS $$
+ BEGIN
+ RETURN i + 5;
+ END;
+ $$ LANGUAGE plpgsql;
+ `},
+ expectedHazardTypes: []diff.MigrationHazardType{diff.MigrationHazardTypeHasUntrackableDependencies},
+ },
+ {
+ name: "Alter sql function to be non-sql function",
+ oldSchemaDDL: []string{
+ `
+ CREATE FUNCTION some_func(i integer) RETURNS integer
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURN i + 5;
+ `},
+ newSchemaDDL: []string{
+ `
+ CREATE FUNCTION some_func(i integer) RETURNS integer AS $$
+ BEGIN
+ RETURN i + 1;
+ END;
+ $$ LANGUAGE plpgsql;
+ `},
+ expectedHazardTypes: []diff.MigrationHazardType{diff.MigrationHazardTypeHasUntrackableDependencies},
+ },
+ {
+ name: "Alter non-sql function to be sql function (no dependency tracking error)",
+ oldSchemaDDL: []string{
+ `
+ CREATE FUNCTION some_func(i integer) RETURNS integer AS $$
+ BEGIN
+ RETURN i + 1;
+ END;
+ $$ LANGUAGE plpgsql;
+ `},
+ newSchemaDDL: []string{
+ `
+ CREATE FUNCTION some_func(i integer) RETURNS integer
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURN i + 5;
+ `},
+ },
+ {
+ name: "Alter a function's dependencies",
+ oldSchemaDDL: []string{
+ `
+ CREATE FUNCTION add(a integer, b integer) RETURNS integer
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN a + b;
+
+ CREATE FUNCTION "increment func"(i integer) RETURNS integer AS $$
+ BEGIN
+ RETURN i + 1;
+ END;
+ $$ LANGUAGE plpgsql;
+
+ CREATE FUNCTION function_with_dependencies(a integer, b integer) RETURNS integer
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN add(a, b) + "increment func"(a);
+
+			-- function with a conflicting name to ensure dependencies specify the parameters
+ CREATE FUNCTION add(a text, b text) RETURNS text
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN CONCAT(a, b);
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE FUNCTION add(a integer, b integer) RETURNS integer
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN a + b;
+
+ CREATE FUNCTION "increment func"(i integer) RETURNS integer AS $$
+ BEGIN
+ RETURN i + 1;
+ END;
+ $$ LANGUAGE plpgsql;
+
+ CREATE FUNCTION "decrement func"(i integer) RETURNS integer AS $$
+ BEGIN
+				RETURN i - 1;
+ END;
+ $$ LANGUAGE plpgsql;
+
+ CREATE FUNCTION function_with_dependencies(a integer, b integer) RETURNS integer
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN add(a, b) + "decrement func"(a);
+
+			-- function with a conflicting name to ensure dependencies specify the parameters
+ CREATE FUNCTION add(a text, b text) RETURNS text
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN CONCAT(a, b);
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{diff.MigrationHazardTypeHasUntrackableDependencies},
+ },
+ {
+ name: "Alter a dependent function",
+ oldSchemaDDL: []string{
+ `
+ CREATE FUNCTION add(a integer, b integer) RETURNS integer
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN a + b;
+
+ CREATE FUNCTION "increment func"(i integer) RETURNS integer AS $$
+ BEGIN
+ RETURN i + 1;
+ END;
+ $$ LANGUAGE plpgsql;
+
+ CREATE FUNCTION function_with_dependencies(a integer, b integer) RETURNS integer
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN add(a, b) + "increment func"(a);
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE FUNCTION add(a integer, b integer) RETURNS integer
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN a + b;
+
+ CREATE FUNCTION "increment func"(i integer) RETURNS int AS $$
+ BEGIN
+ RETURN i + 5;
+ END;
+ $$ LANGUAGE plpgsql;
+
+ CREATE FUNCTION function_with_dependencies(a integer, b integer) RETURNS integer
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN add(a, b) + "increment func"(a);
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{diff.MigrationHazardTypeHasUntrackableDependencies},
+ },
+ {
+ name: "Alter a function to no longer depend on a function and drop that function",
+ oldSchemaDDL: []string{
+ `
+ CREATE FUNCTION add(a integer, b integer) RETURNS integer
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN a + b;
+
+ CREATE FUNCTION "increment func"(i integer) RETURNS integer AS $$
+ BEGIN
+ RETURN i + 1;
+ END;
+ $$ LANGUAGE plpgsql;
+
+ CREATE FUNCTION function_with_dependencies(a integer, b integer) RETURNS integer
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN add(a, b) + "increment func"(a);
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE FUNCTION add(a integer, b integer) RETURNS integer
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN a + b;
+
+
+ CREATE FUNCTION function_with_dependencies(a integer, b integer) RETURNS integer
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN add(a, b);
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{diff.MigrationHazardTypeHasUntrackableDependencies},
+ },
+}
+
+func (suite *acceptanceTestSuite) TestFunctionAcceptanceTestCases() {
+ suite.runTestCases(functionAcceptanceTestCases)
+}
diff --git a/internal/migration_acceptance_tests/index_cases_test.go b/internal/migration_acceptance_tests/index_cases_test.go
new file mode 100644
index 0000000..a52d682
--- /dev/null
+++ b/internal/migration_acceptance_tests/index_cases_test.go
@@ -0,0 +1,553 @@
+package migration_acceptance_tests
+
+import (
+ "github.com/stripe/pg-schema-diff/pkg/diff"
+)
+
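+// indexAcceptanceTestCases covers adding, dropping, and rebuilding plain, unique, and hash indexes,
+// as well as primary keys added on top of existing unique indexes.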
+var indexAcceptanceTestCases = []acceptanceTestCase{
+ {
+ name: "No-op",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255),
+ bar TEXT,
+ fizz INT
+ );
+ CREATE INDEX some_idx ON foobar USING hash (foo);
+ CREATE UNIQUE INDEX some_other_idx ON foobar (bar DESC, fizz);
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255),
+ bar TEXT,
+ fizz INT
+ );
+ CREATE INDEX some_idx ON foobar USING hash (foo);
+ CREATE UNIQUE INDEX some_other_idx ON foobar (bar DESC, fizz);
+ `,
+ },
+ },
+ {
+ name: "Add a normal index",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255)
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255)
+ );
+ CREATE INDEX some_idx ON foobar(id DESC, foo);
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeIndexBuild,
+ },
+ },
+ {
+ name: "Add a hash index",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255)
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255)
+ );
+ CREATE INDEX some_idx ON foobar USING hash (id);
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeIndexBuild,
+ },
+ },
+ {
+ name: "Add a normal index with quoted names",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE "Foobar"(
+ id INT PRIMARY KEY,
+ "Foo" VARCHAR(255)
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE "Foobar"(
+ id INT PRIMARY KEY,
+ "Foo" VARCHAR(255)
+ );
+ CREATE INDEX "Some_idx" ON "Foobar"(id, "Foo");
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeIndexBuild,
+ },
+ },
+ {
+ name: "Add a unique index",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255) NOT NULL
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255) NOT NULL
+ );
+ CREATE UNIQUE INDEX some_unique_idx ON foobar(foo);
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeIndexBuild,
+ },
+ },
+ {
+ name: "Add a primary key",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY
+ );
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeIndexBuild,
+ },
+ },
+ {
+		name: "Add a primary key on a NOT NULL column",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT NOT NULL
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT NOT NULL PRIMARY KEY
+ );
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeIndexBuild,
+ },
+ },
+ {
+ name: "Add a primary key when the index already exists",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT
+ );
+ CREATE UNIQUE INDEX foobar_primary_key ON foobar(id);
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT
+ );
+ CREATE UNIQUE INDEX foobar_primary_key ON foobar(id);
+ ALTER TABLE foobar ADD CONSTRAINT foobar_primary_key PRIMARY KEY USING INDEX foobar_primary_key;
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ },
+ },
+ {
+ name: "Add a primary key when the index already exists but has a name different to the constraint",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT
+ );
+ CREATE UNIQUE INDEX foobar_idx ON foobar(id);
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT
+ );
+ CREATE UNIQUE INDEX foobar_idx ON foobar(id);
+ -- This renames the index
+ ALTER TABLE foobar ADD CONSTRAINT foobar_primary_key PRIMARY KEY USING INDEX foobar_idx;
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeIndexDropped,
+ diff.MigrationHazardTypeIndexBuild,
+ },
+ },
+ {
+ name: "Add a primary key when the index already exists but is not unique",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT
+ );
+ CREATE INDEX foobar_idx ON foobar(id);
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT
+ );
+ CREATE UNIQUE INDEX foobar_primary_key ON foobar(id);
+ ALTER TABLE foobar ADD CONSTRAINT foobar_primary_key PRIMARY KEY USING INDEX foobar_primary_key;
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeIndexDropped,
+ diff.MigrationHazardTypeIndexBuild,
+ },
+ },
+ {
+ name: "Delete a normal index",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255) NOT NULL
+ );
+			CREATE INDEX some_idx ON foobar(id, foo);
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255)
+ );
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeIndexDropped,
+ },
+ },
+ {
+ name: "Delete a normal index with quoted names",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE "Foobar"(
+ id INT PRIMARY KEY,
+ "Foo" VARCHAR(255)
+ );
+ CREATE INDEX "Some_idx" ON "Foobar"(id, "Foo");
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE "Foobar"(
+ id INT PRIMARY KEY,
+ "Foo" VARCHAR(255) NOT NULL
+ );
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeIndexDropped,
+ },
+ },
+ {
+ name: "Delete a unique index",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255) NOT NULL
+ );
+ CREATE UNIQUE INDEX some_unique_idx ON foobar(foo);
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255) NOT NULL
+ );
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeIndexDropped,
+ },
+ },
+ {
+ name: "Delete a primary key",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT
+ );
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeIndexDropped,
+ },
+ },
+ {
+		name: "Change an index's columns (index with a really long name)",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo TEXT NOT NULL,
+ bar BIGINT NOT NULL
+ );
+ CREATE UNIQUE INDEX some_idx_with_a_really_long_name_that_is_nearly_61_chars ON foobar(foo, bar)
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo TEXT NOT NULL,
+ bar BIGINT NOT NULL
+ );
+ CREATE UNIQUE INDEX some_idx_with_a_really_long_name_that_is_nearly_61_chars ON foobar(foo)
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeIndexDropped,
+ diff.MigrationHazardTypeIndexBuild,
+ },
+ },
+ {
+ name: "Change an index type",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo TEXT NOT NULL,
+ bar BIGINT NOT NULL
+ );
+ CREATE INDEX some_idx ON foobar (foo)
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo TEXT NOT NULL,
+ bar BIGINT NOT NULL
+ );
+ CREATE INDEX some_idx ON foobar USING hash (foo)
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeIndexDropped,
+ diff.MigrationHazardTypeIndexBuild,
+ },
+ },
+ {
+		name: "Change an index's column ordering",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo TEXT NOT NULL,
+ bar BIGINT NOT NULL
+ );
+ CREATE INDEX some_idx ON foobar (foo, bar)
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo TEXT NOT NULL,
+ bar BIGINT NOT NULL
+ );
+ CREATE INDEX some_idx ON foobar (foo DESC, bar)
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeIndexDropped,
+ diff.MigrationHazardTypeIndexBuild,
+ },
+ },
+ {
+ name: "Delete columns and associated index",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo TEXT NOT NULL,
+ bar BIGINT NOT NULL
+ );
+ CREATE UNIQUE INDEX some_idx ON foobar(foo, bar);
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT
+ );
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeDeletesData,
+ diff.MigrationHazardTypeIndexDropped,
+ },
+ },
+ {
+ name: "Switch primary key and make old key nullable",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT NOT NULL PRIMARY KEY,
+ foo TEXT NOT NULL
+ );
+ CREATE UNIQUE INDEX some_idx ON foobar(foo);
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo INT NOT NULL PRIMARY KEY
+ );
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeIndexDropped,
+ diff.MigrationHazardTypeIndexBuild,
+ diff.MigrationHazardTypeImpactsDatabasePerformance,
+ },
+ },
+ {
+ name: "Switch primary key with quoted name",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE "Foobar"(
+ "Id" INT NOT NULL PRIMARY KEY,
+ foo TEXT NOT NULL
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE "Foobar"(
+ "Id" INT,
+ foo INT NOT NULL PRIMARY KEY
+ );
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeIndexDropped,
+ diff.MigrationHazardTypeIndexBuild,
+ diff.MigrationHazardTypeImpactsDatabasePerformance,
+ },
+ },
+ {
+ name: "Switch primary key when the original primary key constraint has a non-default name",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT NOT NULL,
+ foo TEXT NOT NULL
+ );
+ CREATE UNIQUE INDEX unique_idx ON foobar(id);
+ ALTER TABLE foobar ADD CONSTRAINT non_default_primary_key PRIMARY KEY USING INDEX unique_idx;
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo INT NOT NULL PRIMARY KEY
+ );
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeImpactsDatabasePerformance,
+ diff.MigrationHazardTypeIndexDropped,
+ diff.MigrationHazardTypeIndexBuild,
+ },
+ },
+ {
+		name: "Alter primary key columns (name stays the same)",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT NOT NULL,
+ foo TEXT NOT NULL
+ );
+ CREATE UNIQUE INDEX unique_idx ON foobar(id);
+ ALTER TABLE foobar ADD CONSTRAINT non_default_primary_key PRIMARY KEY USING INDEX unique_idx;
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo INT NOT NULL
+ );
+ CREATE UNIQUE INDEX unique_idx ON foobar(id, foo);
+ ALTER TABLE foobar ADD CONSTRAINT non_default_primary_key PRIMARY KEY USING INDEX unique_idx;
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeImpactsDatabasePerformance,
+ diff.MigrationHazardTypeIndexDropped,
+ diff.MigrationHazardTypeIndexBuild,
+ },
+ },
+}
+
+func (suite *acceptanceTestSuite) TestIndexAcceptanceTestCases() {
+ suite.runTestCases(indexAcceptanceTestCases)
+}
diff --git a/internal/migration_acceptance_tests/local_partition_index_cases_test.go b/internal/migration_acceptance_tests/local_partition_index_cases_test.go
new file mode 100644
index 0000000..cd96d2b
--- /dev/null
+++ b/internal/migration_acceptance_tests/local_partition_index_cases_test.go
@@ -0,0 +1,442 @@
+package migration_acceptance_tests
+
+import (
+ "github.com/stripe/pg-schema-diff/pkg/diff"
+)
+
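+// localPartitionIndexAcceptanceTestCases covers indexes and primary keys defined directly on individual
+// partitions, as opposed to indexes defined on the parent partitioned table.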
+var localPartitionIndexAcceptanceTestCases = []acceptanceTestCase{
+ {
+ name: "No-op",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255),
+ bar TEXT,
+ fizz INT,
+ PRIMARY KEY (foo, id)
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3');
+
+ CREATE INDEX foobar_1_some_idx ON foobar_1 (foo);
+ CREATE UNIQUE INDEX foobar_2_some_unique_idx ON foobar_2 (foo);
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255),
+ bar TEXT,
+ fizz INT,
+ PRIMARY KEY (foo, id)
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3');
+
+ CREATE INDEX foobar_1_some_idx ON foobar_1 (foo);
+ CREATE UNIQUE INDEX foobar_2_some_unique_idx ON foobar_2 (foo);
+ `,
+ },
+ },
+ {
+ name: "Add local indexes",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255),
+ bar TEXT,
+ fizz BYTEA
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3');
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255),
+ bar TEXT,
+ fizz BYTEA
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3');
+
+ CREATE INDEX foobar_1_some_idx ON foobar_1(foo, id);
+ CREATE INDEX foobar_2_some_idx ON foobar_2(foo, bar);
+ CREATE INDEX foobar_3_some_idx ON foobar_3(foo, fizz);
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeIndexBuild,
+ },
+ },
+ {
+ name: "Add a unique local index",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255)
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3');
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255)
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3');
+
+ CREATE UNIQUE INDEX foobar_1_some_idx ON foobar_1 (foo);
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeIndexBuild,
+ },
+ },
+ {
+ name: "Add local primary keys",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255),
+ bar TEXT,
+ fizz bytea
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3');
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255),
+ bar TEXT,
+ fizz bytea
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar(
+ CONSTRAINT "foobar1_PRIMARY_KEY" PRIMARY KEY (foo, id)
+ ) FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar(
+ CONSTRAINT "foobar2_PRIMARY_KEY" PRIMARY KEY (foo, bar)
+ ) FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar(
+ CONSTRAINT "foobar3_PRIMARY_KEY" PRIMARY KEY (foo, fizz)
+ ) FOR VALUES IN ('foo_3');
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeIndexBuild,
+ },
+ },
+ {
+ name: "Delete a local index",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255),
+ bar TEXT,
+ fizz BYTEA
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3');
+
+ CREATE INDEX foobar_1_some_idx ON foobar_1(foo, id);
+ CREATE INDEX foobar_2_some_idx ON foobar_2(foo, bar);
+ CREATE INDEX foobar_3_some_idx ON foobar_3(foo, fizz);
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255),
+ bar TEXT,
+ fizz BYTEA
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3');
+
+ CREATE INDEX foobar_1_some_idx ON foobar_1(foo, id);
+ CREATE INDEX foobar_3_some_idx ON foobar_3(foo, fizz);
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeIndexDropped,
+ },
+ },
+ {
+ name: "Delete a unique local index",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255)
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3');
+
+ CREATE UNIQUE INDEX some_unique_idx ON foobar_1(foo, id);
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255)
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3');
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeIndexDropped,
+ },
+ },
+ {
+ name: "Delete a primary key",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255),
+ bar TEXT,
+ fizz bytea
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar(
+ CONSTRAINT "foobar1_PRIMARY_KEY" PRIMARY KEY (foo, id)
+ ) FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar(
+ CONSTRAINT "foobar2_PRIMARY_KEY" PRIMARY KEY (foo, bar)
+ ) FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar(
+ CONSTRAINT "foobar3_PRIMARY_KEY" PRIMARY KEY (foo, fizz)
+ ) FOR VALUES IN ('foo_3');
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255),
+ bar TEXT,
+ fizz bytea
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar(
+ CONSTRAINT "foobar1_PRIMARY_KEY" PRIMARY KEY (foo, id)
+ ) FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar(
+ CONSTRAINT "foobar2_PRIMARY_KEY" PRIMARY KEY (foo, bar)
+ ) FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3');
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeIndexDropped,
+ },
+ },
+ {
+		name: "Change an index's columns",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255)
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3');
+
+ CREATE UNIQUE INDEX some_unique_idx ON foobar_1(foo, id);
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255)
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3');
+
+ CREATE UNIQUE INDEX some_unique_idx ON foobar_1(id, foo);
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeIndexDropped,
+ diff.MigrationHazardTypeIndexBuild,
+ },
+ },
+ {
+ name: "Delete columns and associated index",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255)
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3');
+
+ CREATE UNIQUE INDEX some_unique_idx ON foobar_1(foo, id);
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ foo VARCHAR(255)
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3');
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeDeletesData,
+ diff.MigrationHazardTypeIndexDropped,
+ },
+ vanillaExpectations: expectations{
+ planErrorIs: diff.ErrColumnOrderingChanged,
+ },
+ },
+ {
+ name: "Switch primary key",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255),
+ bar TEXT,
+ fizz bytea
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar(
+ CONSTRAINT "foobar1_PRIMARY_KEY" PRIMARY KEY (foo, id)
+ ) FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar(
+ CONSTRAINT "foobar2_PRIMARY_KEY" PRIMARY KEY (foo, bar)
+ ) FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar(
+ CONSTRAINT "foobar3_PRIMARY_KEY" PRIMARY KEY (foo, fizz)
+ ) FOR VALUES IN ('foo_3');
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255),
+ bar TEXT,
+ fizz bytea
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar(
+ CONSTRAINT "foobar1_PRIMARY_KEY" PRIMARY KEY (foo, id)
+ ) FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar(
+ CONSTRAINT "foobar2_PRIMARY_KEY" PRIMARY KEY (foo, fizz)
+ ) FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar(
+ CONSTRAINT "foobar3_PRIMARY_KEY" PRIMARY KEY (foo, fizz)
+ ) FOR VALUES IN ('foo_3');
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeIndexDropped,
+ diff.MigrationHazardTypeIndexBuild,
+ },
+ },
+ {
+ name: "Add a primary key when the index already exists",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255) NOT NULL,
+ bar TEXT NOT NULL,
+ fizz bytea
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE UNIQUE INDEX "foobar1_PRIMARY_KEY" ON foobar_1(foo, id);
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255),
+ bar TEXT,
+ fizz bytea
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar(
+ CONSTRAINT "foobar1_PRIMARY_KEY" PRIMARY KEY (foo, id)
+ ) FOR VALUES IN ('foo_1');
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ },
+ },
+ {
+ name: "Switch partitioned index to local index",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255),
+ bar TEXT,
+ fizz BYTEA
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+
+ CREATE INDEX foobar_some_idx ON foobar(foo, id);
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255),
+ bar TEXT,
+ fizz BYTEA
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+
+ CREATE INDEX foobar_1_foo_id_idx ON foobar_1(foo, id);
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeIndexDropped,
+ diff.MigrationHazardTypeIndexBuild,
+ },
+ },
+}
+
+func (suite *acceptanceTestSuite) TestLocalPartitionIndexAcceptanceTestCases() {
+ suite.runTestCases(localPartitionIndexAcceptanceTestCases)
+}
diff --git a/internal/migration_acceptance_tests/partitioned_index_cases_test.go b/internal/migration_acceptance_tests/partitioned_index_cases_test.go
new file mode 100644
index 0000000..e56438b
--- /dev/null
+++ b/internal/migration_acceptance_tests/partitioned_index_cases_test.go
@@ -0,0 +1,627 @@
+package migration_acceptance_tests
+
+import (
+ "github.com/stripe/pg-schema-diff/pkg/diff"
+)
+
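+// partitionedIndexAcceptanceTestCases covers indexes and primary keys defined on the parent of a
+// partitioned table, including attaching per-partition indexes to a parent partitioned index.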
+var partitionedIndexAcceptanceTestCases = []acceptanceTestCase{
+ {
+ name: "No-op",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255),
+ bar TEXT,
+ fizz INT,
+ PRIMARY KEY (foo, id)
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3');
+ CREATE INDEX some_idx ON foobar USING hash (foo);
+ CREATE UNIQUE INDEX some_other_idx ON foobar(foo DESC, fizz);
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255),
+ bar TEXT,
+ fizz INT,
+ PRIMARY KEY (foo, id)
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3');
+ CREATE INDEX some_idx ON foobar USING hash (foo);
+ CREATE UNIQUE INDEX some_other_idx ON foobar(foo DESC, fizz);
+ `,
+ },
+ },
+ {
+ name: "Add a normal partitioned index",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255)
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3');
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255)
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3');
+
+ CREATE INDEX some_idx ON foobar(id DESC, foo);
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeIndexBuild,
+ },
+ },
+ {
+ name: "Add a hash index",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255)
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3');
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255)
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3');
+
+ CREATE INDEX some_idx ON foobar USING hash (foo);
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeIndexBuild,
+ },
+ },
+ {
+ name: "Add a unique partitioned index",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255)
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3');
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255)
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3');
+ CREATE UNIQUE INDEX some_unique_idx ON foobar(foo, id);
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeIndexBuild,
+ },
+ },
+ {
+		name: "Add a normal partitioned index with quoted names",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE "Foobar"(
+ id INT,
+ "Foo" VARCHAR(255)
+ ) PARTITION BY LIST ("Foo");
+ CREATE TABLE "FOOBAR_1" PARTITION OF "Foobar" FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF "Foobar" FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF "Foobar" FOR VALUES IN ('foo_3');
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE "Foobar"(
+ id INT,
+ "Foo" VARCHAR(255)
+ ) PARTITION BY LIST ("Foo");
+ CREATE TABLE "FOOBAR_1" PARTITION OF "Foobar" FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF "Foobar" FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF "Foobar" FOR VALUES IN ('foo_3');
+
+ CREATE INDEX "SOME_IDX" ON "Foobar"(id, "Foo");
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeIndexBuild,
+ },
+ },
+ {
+ name: "Add a primary key",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255)
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3');
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255),
+ PRIMARY KEY (foo, id)
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3');
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeAcquiresShareLock,
+ diff.MigrationHazardTypeIndexBuild,
+ },
+ },
+ {
+ name: "Add a partitioned index that is used by a local primary key",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255)
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255)
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE UNIQUE INDEX some_idx ON ONLY foobar(foo, id);
+ CREATE UNIQUE INDEX foobar_1_pkey ON foobar_1(foo, id);
+ ALTER TABLE foobar_1 ADD CONSTRAINT foobar_1_pkey PRIMARY KEY USING INDEX foobar_1_pkey;
+ ALTER INDEX some_idx ATTACH PARTITION foobar_1_pkey;
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeIndexBuild,
+ },
+ },
+ {
+ name: "Add a primary key with quoted names",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE "Foobar"(
+ "Id" INT,
+ "FOO" VARCHAR(255)
+ ) PARTITION BY LIST ("FOO");
+ CREATE TABLE foobar_1 PARTITION OF "Foobar" FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF "Foobar" FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF "Foobar" FOR VALUES IN ('foo_3');
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE "Foobar"(
+ "Id" INT,
+ "FOO" VARCHAR(255)
+ ) PARTITION BY LIST ("FOO");
+ CREATE TABLE foobar_1 PARTITION OF "Foobar" FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF "Foobar" FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF "Foobar" FOR VALUES IN ('foo_3');
+ ALTER TABLE "Foobar" ADD CONSTRAINT "FOOBAR_PK" PRIMARY KEY("FOO", "Id")
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeIndexBuild,
+ diff.MigrationHazardTypeAcquiresShareLock,
+ },
+ },
+ {
+ name: "Add a partitioned primary key when the local index already exists",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255)
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE UNIQUE INDEX foobar_1_unique_idx ON foobar_1(foo, id);
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255)
+ ) PARTITION BY LIST (foo);
+ ALTER TABLE foobar ADD CONSTRAINT foobar_pkey PRIMARY KEY (foo, id);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeAcquiresShareLock,
+ diff.MigrationHazardTypeIndexDropped,
+ diff.MigrationHazardTypeIndexBuild,
+ },
+ },
+ {
+ name: "Add a unique index when the local index already exists",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255)
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE UNIQUE INDEX foobar_1_foo_id_idx ON foobar_1(foo, id);
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255)
+ ) PARTITION BY LIST (foo);
+ ALTER TABLE foobar ADD CONSTRAINT foobar_pkey PRIMARY KEY (foo, id);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE UNIQUE INDEX foobar_unique ON foobar(foo, id);
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeAcquiresShareLock,
+ diff.MigrationHazardTypeIndexBuild,
+ },
+ },
+ {
+ name: "Delete a normal partitioned index",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255)
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3');
+ CREATE INDEX some_idx ON foobar(foo, id);
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255)
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3');
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeIndexDropped,
+ },
+ },
+ {
+ name: "Delete a unique partitioned index",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255)
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3');
+ CREATE UNIQUE INDEX some_unique_idx ON foobar(foo, id);
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255)
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3');
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeIndexDropped,
+ },
+ },
+ {
+ name: "Delete a primary key",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255),
+ PRIMARY KEY (foo, id)
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3');
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255)
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3');
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeIndexDropped,
+ },
+ },
+ {
+		name: "Change an index's columns",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255),
+ bar INT
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3');
+
+ CREATE INDEX some_idx ON foobar(id, foo);
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255),
+ bar INT
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3');
+
+ CREATE INDEX some_idx ON foobar(foo, bar);
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeIndexDropped,
+ diff.MigrationHazardTypeIndexBuild,
+ },
+ },
+ {
+ name: "Change an index type",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255),
+ bar INT
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3');
+
+ CREATE INDEX some_idx ON foobar(foo);
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255),
+ bar INT
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3');
+
+ CREATE INDEX some_idx ON foobar USING hash (foo);
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeIndexDropped,
+ diff.MigrationHazardTypeIndexBuild,
+ },
+ },
+ {
+		name: "Change an index's column ordering",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255),
+ bar INT
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3');
+
+ CREATE INDEX some_idx ON foobar (foo, bar);
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255),
+ bar INT
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3');
+
+ CREATE INDEX some_idx ON foobar (bar, foo);
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeIndexDropped,
+ diff.MigrationHazardTypeIndexBuild,
+ },
+ },
+ {
+ name: "Delete columns and associated index",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255)
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3');
+
+ CREATE INDEX some_idx ON foobar(id, foo);
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ foo VARCHAR(255)
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3');
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeDeletesData,
+ diff.MigrationHazardTypeIndexDropped,
+ },
+ vanillaExpectations: expectations{
+ planErrorIs: diff.ErrColumnOrderingChanged,
+ },
+ },
+ {
+ name: "Switch primary key",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255),
+ PRIMARY KEY (foo)
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3');
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255),
+ PRIMARY KEY (foo, id)
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3');
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresShareLock,
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeIndexDropped,
+ diff.MigrationHazardTypeIndexBuild,
+ },
+ },
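+	// An index created with CREATE INDEX ... ON ONLY a partitioned table remains invalid until a matching
+	// index on every partition has been attached via ALTER INDEX ... ATTACH PARTITION.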
+ {
+		name: "Attach an unattached, invalid index",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE "Foobar"(
+ id INT,
+ foo VARCHAR(255),
+ PRIMARY KEY (foo)
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF "Foobar" FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF "Foobar" FOR VALUES IN ('foo_2');
+ CREATE TABLE "Foobar_3" PARTITION OF "Foobar" FOR VALUES IN ('foo_3');
+
+ CREATE INDEX "Partitioned_Idx" ON ONLY "Foobar"(foo);
+
+ CREATE INDEX "foobar_1_part" ON foobar_1(foo);
+ ALTER INDEX "Partitioned_Idx" ATTACH PARTITION "foobar_1_part";
+
+ CREATE INDEX "foobar_2_part" ON foobar_2(foo);
+ ALTER INDEX "Partitioned_Idx" ATTACH PARTITION "foobar_2_part";
+
+ CREATE INDEX "Foobar_3_Part" ON "Foobar_3"(foo);
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE "Foobar"(
+ id INT,
+ foo VARCHAR(255),
+ PRIMARY KEY (foo)
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE foobar_1 PARTITION OF "Foobar" FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF "Foobar" FOR VALUES IN ('foo_2');
+ CREATE TABLE "Foobar_3" PARTITION OF "Foobar" FOR VALUES IN ('foo_3');
+
+ CREATE INDEX "Partitioned_Idx" ON ONLY "Foobar"(foo);
+
+ CREATE INDEX "foobar_1_part" ON foobar_1(foo);
+ ALTER INDEX "Partitioned_Idx" ATTACH PARTITION "foobar_1_part";
+
+ CREATE INDEX "foobar_2_part" ON foobar_2(foo);
+ ALTER INDEX "Partitioned_Idx" ATTACH PARTITION "foobar_2_part";
+
+ CREATE INDEX "Foobar_3_Part" ON "Foobar_3"(foo);
+ ALTER INDEX "Partitioned_Idx" ATTACH PARTITION "Foobar_3_Part";
+ `,
+ },
+ },
+}
+
+func (suite *acceptanceTestSuite) TestPartitionedIndexAcceptanceTestCases() {
+ suite.runTestCases(partitionedIndexAcceptanceTestCases)
+}
diff --git a/internal/migration_acceptance_tests/partitioned_table_cases_test.go b/internal/migration_acceptance_tests/partitioned_table_cases_test.go
new file mode 100644
index 0000000..d764db0
--- /dev/null
+++ b/internal/migration_acceptance_tests/partitioned_table_cases_test.go
@@ -0,0 +1,810 @@
+package migration_acceptance_tests
+
+import (
+ "github.com/stripe/pg-schema-diff/pkg/diff"
+)
+
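+// partitionedTableAcceptanceTestCases covers creating, dropping, and altering partitioned tables,
+// including shared and per-partition primary keys, partitioned and local indexes, and check constraints.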
+var partitionedTableAcceptanceTestCases = []acceptanceTestCase{
+ {
+ name: "No-op",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255),
+ bar TEXT COLLATE "POSIX",
+ fizz INT,
+ CHECK ( fizz > 0 ),
+ PRIMARY KEY (foo, id)
+ ) PARTITION BY LIST (foo);
+
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3');
+
+ -- partitioned indexes
+ CREATE UNIQUE INDEX foobar_unique_idx ON foobar(foo, fizz);
+
+ -- local indexes
+ CREATE INDEX foobar_1_local_idx ON foobar_1(foo);
+ CREATE UNIQUE INDEX foobar_2_local_unique_idx ON foobar_2(foo);
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255),
+ bar TEXT COLLATE "POSIX",
+ fizz INT,
+ CHECK ( fizz > 0 ),
+ PRIMARY KEY (foo, id)
+ ) PARTITION BY LIST (foo);
+
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3');
+
+ -- partitioned indexes
+ CREATE UNIQUE INDEX foobar_unique_idx ON foobar(foo, fizz);
+
+ -- local indexes
+ CREATE INDEX foobar_1_local_idx ON foobar_1(foo);
+ CREATE UNIQUE INDEX foobar_2_local_unique_idx ON foobar_2(foo);
+ `,
+ },
+ },
+ {
+ name: "Create partitioned table with shared primary key",
+ oldSchemaDDL: nil,
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE "Foobar"(
+ id INT,
+ foo VARCHAR(255),
+ bar TEXT COLLATE "POSIX" NOT NULL DEFAULT 'some default',
+ fizz INT,
+ CHECK ( fizz > 0 ),
+ PRIMARY KEY (foo, id)
+ ) PARTITION BY LIST (foo);
+
+ CREATE TABLE "FOOBAR_1" PARTITION OF "Foobar"(
+ foo NOT NULL,
+ bar NOT NULL
+ ) FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF "Foobar" FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF "Foobar" FOR VALUES IN ('foo_3');
+
+ -- partitioned indexes
+ CREATE UNIQUE INDEX foobar_unique_idx ON "Foobar"(foo, fizz);
+
+ -- local indexes
+ CREATE INDEX foobar_1_local_idx ON "FOOBAR_1"(foo);
+ CREATE UNIQUE INDEX foobar_2_local_unique_idx ON foobar_2(foo);
+ `,
+ },
+ dataPackingExpectations: expectations{
+ outputState: []string{
+ `
+ CREATE TABLE "Foobar"(
+ id INT,
+ fizz INT,
+ foo VARCHAR(255),
+ bar TEXT COLLATE "POSIX" NOT NULL DEFAULT 'some default',
+ CHECK ( fizz > 0 ),
+ PRIMARY KEY (foo, id)
+ ) PARTITION BY LIST (foo);
+
+ CREATE TABLE "FOOBAR_1" PARTITION OF "Foobar"(
+ foo NOT NULL,
+ bar NOT NULL
+ ) FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF "Foobar" FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF "Foobar" FOR VALUES IN ('foo_3');
+
+ -- partitioned indexes
+ CREATE UNIQUE INDEX foobar_unique_idx ON "Foobar"(foo, fizz);
+
+ -- local indexes
+ CREATE INDEX foobar_1_local_idx ON "FOOBAR_1"(foo);
+ CREATE UNIQUE INDEX foobar_2_local_unique_idx ON foobar_2(foo);
+ `,
+ },
+ },
+ },
+ {
+ name: "Create partitioned table with local primary keys",
+ oldSchemaDDL: nil,
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE "Foobar"(
+ id INT,
+ foo VARCHAR(255),
+ bar TEXT,
+ fizz INT,
+ CHECK ( fizz > 0 )
+ ) PARTITION BY LIST (foo);
+ CREATE TABLE "FOOBAR_1" PARTITION OF "Foobar"(
+ foo NOT NULL,
+ bar NOT NULL,
+ PRIMARY KEY (foo, id)
+ ) FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF "Foobar"(
+ PRIMARY KEY (foo, bar)
+ ) FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF "Foobar"(
+ PRIMARY KEY (foo, fizz)
+ ) FOR VALUES IN ('foo_3');
+
+ -- partitioned indexes
+ CREATE UNIQUE INDEX foobar_unique_idx ON "Foobar"(foo, fizz);
+
+ -- local indexes
+ CREATE INDEX foobar_1_local_idx ON "FOOBAR_1"(foo);
+ CREATE UNIQUE INDEX foobar_2_local_unique_idx ON foobar_2(foo);
+ `,
+ },
+ dataPackingExpectations: expectations{
+ outputState: []string{
+ `
+ CREATE TABLE "Foobar"(
+ id INT,
+ fizz INT,
+ foo VARCHAR(255),
+ bar TEXT,
+ CHECK ( fizz > 0 )
+ ) PARTITION BY LIST (foo);
+
+ CREATE TABLE "FOOBAR_1" PARTITION OF "Foobar"(
+ foo NOT NULL,
+ bar NOT NULL,
+ PRIMARY KEY (foo, id)
+ ) FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF "Foobar"(
+ PRIMARY KEY (foo, bar)
+ ) FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF "Foobar"(
+ PRIMARY KEY (foo, fizz)
+ ) FOR VALUES IN ('foo_3');
+
+ -- partitioned indexes
+ CREATE UNIQUE INDEX foobar_unique_idx ON "Foobar"(foo, fizz);
+
+ -- local indexes
+ CREATE INDEX foobar_1_local_idx ON "FOOBAR_1"(foo);
+ CREATE UNIQUE INDEX foobar_2_local_unique_idx ON foobar_2(foo);
+ `,
+ },
+ },
+ },
+ {
+ name: "Drop table",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE "Foobar"(
+ id INT,
+ foo VARCHAR(255),
+ bar TEXT,
+ fizz INT,
+ CHECK ( fizz > 0 ),
+ PRIMARY KEY (foo, id)
+ ) PARTITION BY LIST (foo);
+
+ CREATE TABLE "FOOBAR_1" PARTITION OF "Foobar"(
+ foo NOT NULL,
+ bar NOT NULL
+ ) FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF "Foobar" FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF "Foobar" FOR VALUES IN ('foo_3');
+
+ -- partitioned indexes
+ CREATE UNIQUE INDEX foobar_unique_idx ON "Foobar"(foo, fizz);
+
+ -- local indexes
+ CREATE INDEX foobar_1_local_idx ON "FOOBAR_1"(foo);
+ CREATE UNIQUE INDEX foobar_2_local_unique_idx ON foobar_2(foo);
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeDeletesData,
+ },
+ newSchemaDDL: nil,
+ },
+ {
+ name: "Alter table: New primary key, change column types, delete unique partitioned index, new partitioned index, delete local index, add local index, validate check constraint",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255),
+ bar TEXT COLLATE "C",
+ fizz INT,
+ PRIMARY KEY (foo, id)
+ ) PARTITION BY LIST (foo);
+
+ ALTER TABLE foobar ADD CONSTRAINT some_check_constraint CHECK ( fizz > 0 ) NOT VALID;
+
+ CREATE TABLE foobar_1 PARTITION OF foobar(
+ foo NOT NULL,
+ bar NOT NULL
+ ) FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3');
+
+ -- partitioned indexes
+ CREATE INDEX foobar_some_idx ON foobar(foo, bar);
+
+ -- local indexes
+ CREATE INDEX foobar_1_local_idx ON foobar_1(foo);
+ CREATE UNIQUE INDEX foobar_2_local_unique_idx ON foobar_2(foo);
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id TEXT,
+ foo VARCHAR(255),
+ bar TEXT COLLATE "POSIX",
+ fizz INT,
+ PRIMARY KEY (foo, id)
+ ) PARTITION BY LIST (foo);
+ ALTER TABLE foobar ADD CONSTRAINT some_check_constraint CHECK ( fizz > 0 );
+
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ CREATE TABLE foobar_2 PARTITION OF foobar(
+ bar NOT NULL
+ ) FOR VALUES IN ('foo_2');
+ CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3');
+
+ -- partitioned indexes
+ CREATE UNIQUE INDEX foobar_unique_idx ON foobar(foo, fizz);
+
+ -- local indexes
+ CREATE UNIQUE INDEX foobar_2_local_unique_idx ON foobar_2(foo);
+ CREATE INDEX foobar_3_local_idx ON foobar_3(foo, bar);
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeImpactsDatabasePerformance,
+ diff.MigrationHazardTypeIndexDropped,
+ diff.MigrationHazardTypeIndexBuild,
+ },
+ },
+ {
+ name: "Changing partition key def errors",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255),
+ bar TEXT,
+ fizz INT
+ ) PARTITION BY LIST (foo);
+
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255),
+ bar TEXT,
+ fizz INT
+ ) PARTITION BY LIST (bar);
+
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ `,
+ },
+ vanillaExpectations: expectations{
+ planErrorIs: diff.ErrNotImplemented,
+ },
+ dataPackingExpectations: expectations{
+ planErrorIs: diff.ErrNotImplemented,
+ },
+ },
+ {
+ name: "Unpartitioned to partitioned",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ version INT,
+ foo VARCHAR(255),
+ bar TEXT
+ );
+
+ CREATE INDEX some_idx on foobar(id);
+
+
+ CREATE FUNCTION increment_version() RETURNS TRIGGER AS $$
+ BEGIN
+ NEW.version = OLD.version + 1;
+ RETURN NEW;
+ END;
+ $$ language 'plpgsql';
+
+ CREATE TRIGGER some_update_trigger
+ BEFORE UPDATE ON foobar
+ FOR EACH ROW
+ WHEN (OLD.* IS DISTINCT FROM NEW.*)
+ EXECUTE PROCEDURE increment_version();
+
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ version INT,
+ foo VARCHAR(255),
+ bar TEXT
+ ) PARTITION BY LIST (foo);
+
+ CREATE INDEX some_idx on foobar(id);
+
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+
+ CREATE FUNCTION increment_version() RETURNS TRIGGER AS $$
+ BEGIN
+ NEW.version = OLD.version + 1;
+ RETURN NEW;
+ END;
+ $$ language 'plpgsql';
+
+ CREATE TRIGGER some_update_trigger
+ BEFORE UPDATE ON foobar
+ FOR EACH ROW
+ WHEN (OLD.* IS DISTINCT FROM NEW.*)
+ EXECUTE PROCEDURE increment_version();
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeDeletesData,
+ },
+ },
+ {
+ name: "Unpartitioned to partitioned and child tables already exist",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ version INT,
+ foo VARCHAR(255),
+ bar TEXT
+ );
+
+ CREATE INDEX some_idx on foobar(id);
+
+ CREATE TABLE foobar_1(
+ id INT,
+ version INT,
+ foo VARCHAR(255),
+ bar TEXT
+ );
+
+ CREATE INDEX foobar_1_id_idx on foobar(id);
+
+
+ CREATE FUNCTION increment_version() RETURNS TRIGGER AS $$
+ BEGIN
+ NEW.version = OLD.version + 1;
+ RETURN NEW;
+ END;
+ $$ language 'plpgsql';
+
+ CREATE TRIGGER some_update_trigger
+ BEFORE UPDATE ON foobar
+ FOR EACH ROW
+ WHEN (OLD.* IS DISTINCT FROM NEW.*)
+ EXECUTE PROCEDURE increment_version();
+
+ CREATE TRIGGER some_update_trigger
+ BEFORE UPDATE ON foobar_1
+ FOR EACH ROW
+ WHEN (OLD.* IS DISTINCT FROM NEW.*)
+ EXECUTE PROCEDURE increment_version();
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ version INT,
+ foo VARCHAR(255),
+ bar TEXT
+ ) PARTITION BY LIST (foo);
+
+ CREATE INDEX some_idx on foobar(id);
+
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+
+
+ CREATE FUNCTION increment_version() RETURNS TRIGGER AS $$
+ BEGIN
+ NEW.version = OLD.version + 1;
+ RETURN NEW;
+ END;
+ $$ language 'plpgsql';
+
+ CREATE TRIGGER some_update_trigger
+ BEFORE UPDATE ON foobar
+ FOR EACH ROW
+ WHEN (OLD.* IS DISTINCT FROM NEW.*)
+ EXECUTE PROCEDURE increment_version();
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeDeletesData,
+ },
+ },
+ {
+ name: "Partitioned to unpartitioned",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ version INT,
+ foo VARCHAR(255),
+ bar TEXT
+ ) PARTITION BY LIST (foo);
+
+ CREATE INDEX some_idx on foobar(id);
+
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+
+
+ CREATE FUNCTION increment_version() RETURNS TRIGGER AS $$
+ BEGIN
+ NEW.version = OLD.version + 1;
+ RETURN NEW;
+ END;
+ $$ language 'plpgsql';
+
+ CREATE TRIGGER some_update_trigger
+ BEFORE UPDATE ON foobar
+ FOR EACH ROW
+ WHEN (OLD.* IS DISTINCT FROM NEW.*)
+ EXECUTE PROCEDURE increment_version();
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ version INT,
+ foo VARCHAR(255),
+ bar TEXT
+ );
+
+ CREATE INDEX some_idx on foobar(id);
+
+
+ CREATE FUNCTION increment_version() RETURNS TRIGGER AS $$
+ BEGIN
+ NEW.version = OLD.version + 1;
+ RETURN NEW;
+ END;
+ $$ language 'plpgsql';
+
+ CREATE TRIGGER some_update_trigger
+ BEFORE UPDATE ON foobar
+ FOR EACH ROW
+ WHEN (OLD.* IS DISTINCT FROM NEW.*)
+ EXECUTE PROCEDURE increment_version();
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeDeletesData,
+ },
+ },
+ {
+ name: "Partitioned to unpartitioned and child tables still exist",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ version INT,
+ foo VARCHAR(255),
+ bar TEXT
+ ) PARTITION BY LIST (foo);
+
+ CREATE INDEX some_idx on foobar(id);
+
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+
+
+ CREATE FUNCTION increment_version() RETURNS TRIGGER AS $$
+ BEGIN
+ NEW.version = OLD.version + 1;
+ RETURN NEW;
+ END;
+ $$ language 'plpgsql';
+
+ CREATE TRIGGER some_update_trigger
+ BEFORE UPDATE ON foobar
+ FOR EACH ROW
+ WHEN (OLD.* IS DISTINCT FROM NEW.*)
+ EXECUTE PROCEDURE increment_version();
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ version INT,
+ foo VARCHAR(255),
+ bar TEXT
+ );
+
+ CREATE INDEX some_idx on foobar(id);
+
+ CREATE TABLE foobar_1(
+ id INT,
+ version INT,
+ foo VARCHAR(255),
+ bar TEXT
+ );
+
+ CREATE INDEX foobar_1_id_idx on foobar(id);
+
+
+ CREATE FUNCTION increment_version() RETURNS TRIGGER AS $$
+ BEGIN
+ NEW.version = OLD.version + 1;
+ RETURN NEW;
+ END;
+ $$ language 'plpgsql';
+
+ CREATE TRIGGER some_update_trigger
+ BEFORE UPDATE ON foobar
+ FOR EACH ROW
+ WHEN (OLD.* IS DISTINCT FROM NEW.*)
+ EXECUTE PROCEDURE increment_version();
+
+ CREATE TRIGGER some_update_trigger
+ BEFORE UPDATE ON foobar_1
+ FOR EACH ROW
+ WHEN (OLD.* IS DISTINCT FROM NEW.*)
+ EXECUTE PROCEDURE increment_version();
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeDeletesData,
+ },
+ },
+ {
+ name: "Adding a partition",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ fizz INT,
+ foo VARCHAR(255),
+ bar TEXT,
+ CHECK ( fizz > 0 ),
+ PRIMARY KEY (foo, id)
+ ) PARTITION BY LIST (foo);
+
+ -- partitioned indexes
+ CREATE UNIQUE INDEX some_partitioned_idx ON foobar(foo, bar);
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ fizz INT,
+ foo VARCHAR(255),
+ bar TEXT,
+ CHECK ( fizz > 0 ),
+ PRIMARY KEY (foo, id)
+ ) PARTITION BY LIST (foo);
+
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+
+ -- partitioned indexes
+ CREATE UNIQUE INDEX some_partitioned_idx ON foobar(foo, bar);
+ `,
+ },
+ },
+ {
+ name: "Adding a partition with local primary key that can back the unique index",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ fizz INT,
+ foo VARCHAR(255),
+ bar TEXT,
+ CHECK ( fizz > 0 )
+ ) PARTITION BY LIST (foo);
+
+ -- partitioned indexes
+ CREATE UNIQUE INDEX some_partitioned_idx ON foobar(foo, bar);
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ fizz INT,
+ foo VARCHAR(255),
+ bar TEXT,
+ CHECK ( fizz > 0 )
+ ) PARTITION BY LIST (foo);
+
+ CREATE TABLE foobar_1 PARTITION OF foobar(
+ PRIMARY KEY (foo, bar)
+ ) FOR VALUES IN ('foo_1');
+
+ -- partitioned indexes
+ CREATE UNIQUE INDEX some_partitioned_idx ON foobar(foo, bar);
+ `,
+ },
+ },
+ {
+ name: "Deleting a partition errors",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255),
+ bar TEXT,
+ fizz INT
+ ) PARTITION BY LIST (foo);
+
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255),
+ bar TEXT,
+ fizz INT
+ ) PARTITION BY LIST (foo);
+ `,
+ },
+ vanillaExpectations: expectations{
+ planErrorIs: diff.ErrNotImplemented,
+ },
+ dataPackingExpectations: expectations{
+ planErrorIs: diff.ErrNotImplemented,
+ },
+ },
+ {
+ name: "Altering a partition's 'FOR VALUES' errors",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255),
+ bar TEXT,
+ fizz INT
+ ) PARTITION BY LIST (foo);
+
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255),
+ bar TEXT,
+ fizz INT
+ ) PARTITION BY LIST (foo);
+
+ CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_2');
+ `,
+ },
+ vanillaExpectations: expectations{
+ planErrorIs: diff.ErrNotImplemented,
+ },
+ dataPackingExpectations: expectations{
+ planErrorIs: diff.ErrNotImplemented,
+ },
+ },
+ {
+ name: "Re-creating base table causes partitions to be re-created",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ fizz INT,
+ foo VARCHAR(255),
+ bar TEXT
+ ) PARTITION BY LIST (foo);
+
+ CREATE TABLE foobar_1 PARTITION OF foobar(
+ PRIMARY KEY (foo, fizz)
+ ) FOR VALUES IN ('foo_1');
+
+ -- partitioned indexes
+ CREATE INDEX some_partitioned_idx ON foobar(foo, bar);
+ -- local indexes
+ CREATE INDEX some_local_idx ON foobar_1(foo, bar);
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar_new(
+ id INT,
+ fizz INT,
+ foo VARCHAR(255),
+ bar TEXT
+ ) PARTITION BY LIST (foo);
+
+ CREATE TABLE foobar_1 PARTITION OF foobar_new(
+ PRIMARY KEY (foo, fizz)
+ ) FOR VALUES IN ('foo_1');
+
+ -- partitioned indexes
+ CREATE INDEX some_partitioned_idx ON foobar_new(foo, bar);
+ -- local indexes
+ CREATE INDEX some_local_idx ON foobar_1(foo, bar);
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeDeletesData,
+ },
+ },
+ {
+ name: "Can handle scenario where partition is not attached",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ fizz INT,
+ foo VARCHAR(255),
+ bar TEXT
+ ) PARTITION BY LIST (foo);
+
+ CREATE TABLE foobar_1 (
+ id INT,
+ fizz INT,
+ foo VARCHAR(255),
+ bar TEXT
+ );
+
+ -- partitioned indexes
+ CREATE INDEX some_partitioned_idx ON foobar(foo, bar);
+ -- local indexes
+ CREATE INDEX some_local_idx ON foobar_1(foo, bar);
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ fizz INT,
+ foo VARCHAR(255),
+ bar TEXT
+ ) PARTITION BY LIST (foo);
+
+ CREATE TABLE foobar_1 (
+ id INT,
+ fizz INT,
+ foo VARCHAR(255),
+ bar TEXT
+ );
+ ALTER TABLE foobar ATTACH PARTITION foobar_1 FOR VALUES IN ('foo_1');
+
+ -- partitioned indexes
+ CREATE INDEX some_partitioned_idx ON foobar(foo, bar);
+ -- local indexes
+ CREATE INDEX some_local_idx ON foobar_1(foo, bar);
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeDeletesData,
+ },
+ },
+}
+
+func (suite *acceptanceTestSuite) TestPartitionedTableAcceptanceTestCases() {
+ suite.runTestCases(partitionedTableAcceptanceTestCases)
+}
diff --git a/internal/migration_acceptance_tests/schema_cases_test.go b/internal/migration_acceptance_tests/schema_cases_test.go
new file mode 100644
index 0000000..24e0b39
--- /dev/null
+++ b/internal/migration_acceptance_tests/schema_cases_test.go
@@ -0,0 +1,446 @@
+package migration_acceptance_tests
+
+import (
+ "github.com/stripe/pg-schema-diff/pkg/diff"
+)
+
+// These are tests for "public" schema" alterations (full migrations)
+var schemaAcceptanceTests = []acceptanceTestCase{
+ {
+ name: "No-op",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE fizz(
+ );
+
+ CREATE FUNCTION add(a integer, b integer) RETURNS integer
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN a + b;
+
+ CREATE FUNCTION increment(i integer) RETURNS integer AS $$
+ BEGIN
+ RETURN i + 1;
+ END;
+ $$ LANGUAGE plpgsql;
+
+ CREATE FUNCTION function_with_dependencies(a integer, b integer) RETURNS integer
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN add(a, b) + increment(a);
+
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255) DEFAULT 'some default' NOT NULL CHECK (LENGTH(foo) > 0),
+ bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ fizz BOOLEAN NOT NULL,
+ PRIMARY KEY (foo, id)
+ ) PARTITION BY LIST(foo);
+
+ CREATE TABLE foobar_1 PARTITION of foobar(
+ fizz NOT NULL
+ ) FOR VALUES IN ('foobar_1_val_1', 'foobar_1_val_2');
+
+ -- partitioned indexes
+ CREATE INDEX foobar_normal_idx ON foobar(foo DESC, fizz);
+ CREATE INDEX foobar_hash_idx ON foobar USING hash (foo);
+ CREATE UNIQUE INDEX foobar_unique_idx ON foobar(foo, bar);
+ -- local indexes
+ CREATE INDEX foobar_1_local_idx ON foobar_1(foo, fizz);
+
+ CREATE table bar(
+ id VARCHAR(255) PRIMARY KEY,
+ foo INT DEFAULT 0 CHECK (foo > 0 AND foo > bar),
+ bar DOUBLE PRECISION NOT NULL DEFAULT 8.8,
+ fizz timestamptz DEFAULT CURRENT_TIMESTAMP,
+ buzz REAL NOT NULL CHECK (buzz IS NOT NULL)
+ );
+ CREATE INDEX bar_normal_idx ON bar(bar);
+ CREATE INDEX bar_another_normal_id ON bar(bar, fizz);
+ CREATE UNIQUE INDEX bar_unique_idx on bar(fizz, buzz);
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE fizz(
+ );
+
+ CREATE FUNCTION add(a integer, b integer) RETURNS integer
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN a + b;
+
+ CREATE FUNCTION increment(i integer) RETURNS integer AS $$
+ BEGIN
+ RETURN i + 1;
+ END;
+ $$ LANGUAGE plpgsql;
+
+ CREATE FUNCTION function_with_dependencies(a integer, b integer) RETURNS integer
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN add(a, b) + increment(a);
+
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255) DEFAULT 'some default' NOT NULL CHECK (LENGTH(foo) > 0),
+ bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ fizz BOOLEAN NOT NULL,
+ PRIMARY KEY (foo, id)
+ ) PARTITION BY LIST(foo);
+
+ CREATE TABLE foobar_1 PARTITION of foobar(
+ fizz NOT NULL
+ ) FOR VALUES IN ('foobar_1_val_1', 'foobar_1_val_2');
+
+ -- partitioned indexes
+ CREATE INDEX foobar_normal_idx ON foobar(foo DESC, fizz);
+ CREATE INDEX foobar_hash_idx ON foobar USING hash (foo);
+ CREATE UNIQUE INDEX foobar_unique_idx ON foobar(foo, bar);
+ -- local indexes
+ CREATE INDEX foobar_1_local_idx ON foobar_1(foo, fizz);
+
+ CREATE table bar(
+ id VARCHAR(255) PRIMARY KEY,
+ foo INT DEFAULT 0 CHECK (foo > 0 AND foo > bar),
+ bar DOUBLE PRECISION NOT NULL DEFAULT 8.8,
+ fizz timestamptz DEFAULT CURRENT_TIMESTAMP,
+ buzz REAL NOT NULL CHECK (buzz IS NOT NULL)
+ );
+ CREATE INDEX bar_normal_idx ON bar(bar);
+ CREATE INDEX bar_another_normal_id ON bar(bar, fizz);
+ CREATE UNIQUE INDEX bar_unique_idx on bar(fizz, buzz);
+ `,
+ },
+ },
+ {
+ name: "Drop table, Add Table, Drop Funcs, Add Funcs, Drop Triggers, Add Triggers",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE fizz(
+ );
+
+ CREATE FUNCTION add(a integer, b integer) RETURNS integer
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN a + b;
+
+ CREATE FUNCTION increment(i integer) RETURNS integer AS $$
+ BEGIN
+ RETURN i + 1;
+ END;
+ $$ LANGUAGE plpgsql;
+
+ CREATE FUNCTION function_with_dependencies(a integer, b integer) RETURNS integer
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN add(a, b) + increment(a);
+
+ CREATE FUNCTION "increment version"() RETURNS TRIGGER AS $$
+ BEGIN
+ NEW.version = increment(OLD.version);
+ RETURN NEW;
+ END;
+ $$ language 'plpgsql';
+
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ foo VARCHAR(255) DEFAULT 'some default' NOT NULL,
+ bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ fizz BOOLEAN NOT NULL
+ );
+ CREATE INDEX foobar_normal_idx ON foobar USING hash (fizz);
+ CREATE UNIQUE INDEX foobar_unique_idx ON foobar(foo, bar DESC);
+
+ CREATE TRIGGER "some trigger"
+ BEFORE UPDATE ON foobar
+ FOR EACH ROW
+ WHEN (OLD.* IS DISTINCT FROM NEW.*)
+ EXECUTE PROCEDURE "increment version"();
+
+ CREATE table bar(
+ id VARCHAR(255) PRIMARY KEY,
+ foo INT DEFAULT 0,
+ bar DOUBLE PRECISION NOT NULL DEFAULT 8.8,
+ fizz timestamptz DEFAULT CURRENT_TIMESTAMP,
+ buzz REAL NOT NULL
+ );
+ ALTER TABLE bar ADD CONSTRAINT "FOO_CHECK" CHECK (foo < bar) NOT VALID;
+ CREATE INDEX bar_normal_idx ON bar(bar);
+ CREATE INDEX bar_another_normal_id ON bar(bar DESC, fizz DESC);
+ CREATE UNIQUE INDEX bar_unique_idx on bar(fizz, buzz);
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE fizz(
+ );
+
+ CREATE FUNCTION "new add"(a integer, b integer) RETURNS integer
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN a + b + a;
+
+ CREATE FUNCTION "new increment"(i integer) RETURNS integer AS $$
+ BEGIN
+ RETURN i + 2;
+ END;
+ $$ LANGUAGE plpgsql;
+
+ CREATE FUNCTION "new function with dependencies"(a integer, b integer) RETURNS integer
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN "new add"(a, b) + "new increment"(a);
+
+ CREATE FUNCTION "increment version"() RETURNS TRIGGER AS $$
+ BEGIN
+ NEW.version = "new increment"(OLD.version);
+ RETURN NEW;
+ END;
+ $$ language 'plpgsql';
+
+ CREATE TABLE "New_table"(
+ id INT PRIMARY KEY,
+ new_foo VARCHAR(255) DEFAULT '' NOT NULL CHECK ( new_foo IS NOT NULL),
+ new_bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ version INT NOT NULL DEFAULT 0
+ );
+ ALTER TABLE "New_table" ADD CONSTRAINT "new_bar_check" CHECK ( new_bar < CURRENT_TIMESTAMP - interval '1 month' ) NO INHERIT NOT VALID;
+ CREATE UNIQUE INDEX foobar_unique_idx ON "New_table"(new_foo, new_bar);
+
+ CREATE TRIGGER "some trigger"
+ BEFORE UPDATE ON "New_table"
+ FOR EACH ROW
+ WHEN (OLD.* IS DISTINCT FROM NEW.*)
+ EXECUTE PROCEDURE "increment version"();
+
+ CREATE TABLE bar(
+ id VARCHAR(255) PRIMARY KEY,
+ foo INT,
+ bar DOUBLE PRECISION NOT NULL DEFAULT 8.8,
+ fizz timestamptz DEFAULT CURRENT_TIMESTAMP,
+ buzz REAL NOT NULL
+ );
+ ALTER TABLE bar ADD CONSTRAINT "FOO_CHECK" CHECK ( foo < bar );
+ CREATE INDEX bar_normal_idx ON bar(bar);
+ CREATE INDEX bar_another_normal_id ON bar(bar DESC, fizz DESC);
+ CREATE UNIQUE INDEX bar_unique_idx on bar(fizz, buzz);
+
+ CREATE FUNCTION check_content() RETURNS TRIGGER AS $$
+ BEGIN
+ IF LENGTH(NEW.id) = 0 THEN
+ RAISE EXCEPTION 'content is empty';
+ END IF;
+ END;
+ $$ language 'plpgsql';
+
+ CREATE TRIGGER some_check_trigger
+ BEFORE UPDATE ON bar
+ FOR EACH ROW
+ EXECUTE FUNCTION check_content();
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeDeletesData,
+ diff.MigrationHazardTypeHasUntrackableDependencies,
+ },
+ dataPackingExpectations: expectations{
+ outputState: []string{
+ `
+ CREATE TABLE fizz(
+ );
+
+ CREATE FUNCTION "new add"(a integer, b integer) RETURNS integer
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN a + b + a;
+
+ CREATE FUNCTION "new increment"(i integer) RETURNS integer AS $$
+ BEGIN
+ RETURN i + 2;
+ END;
+ $$ LANGUAGE plpgsql;
+
+ CREATE FUNCTION "new function with dependencies"(a integer, b integer) RETURNS integer
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN "new add"(a, b) + "new increment"(a);
+
+ CREATE FUNCTION "increment version"() RETURNS TRIGGER AS $$
+ BEGIN
+ NEW.version = "new increment"(OLD.version);
+ RETURN NEW;
+ END;
+ $$ language 'plpgsql';
+
+ CREATE TABLE "New_table"(
+ new_bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ id INT PRIMARY KEY,
+ version INT NOT NULL DEFAULT 0,
+ new_foo VARCHAR(255) DEFAULT '' NOT NULL CHECK (new_foo IS NOT NULL)
+ );
+ ALTER TABLE "New_table" ADD CONSTRAINT "new_bar_check" CHECK ( new_bar < CURRENT_TIMESTAMP - interval '1 month' ) NO INHERIT NOT VALID;
+ CREATE UNIQUE INDEX foobar_unique_idx ON "New_table"(new_foo, new_bar);
+
+ CREATE TRIGGER "some trigger"
+ BEFORE UPDATE ON "New_table"
+ FOR EACH ROW
+ WHEN (OLD.* IS DISTINCT FROM NEW.*)
+ EXECUTE PROCEDURE "increment version"();
+
+ CREATE TABLE bar(
+ id VARCHAR(255) PRIMARY KEY,
+ foo INT,
+ bar DOUBLE PRECISION NOT NULL DEFAULT 8.8,
+ fizz timestamptz DEFAULT CURRENT_TIMESTAMP,
+ buzz REAL NOT NULL
+ );
+ ALTER TABLE bar ADD CONSTRAINT "FOO_CHECK" CHECK ( foo < bar );
+ CREATE INDEX bar_normal_idx ON bar(bar);
+ CREATE INDEX bar_another_normal_id ON bar(bar DESC, fizz DESC);
+ CREATE UNIQUE INDEX bar_unique_idx on bar(fizz, buzz);
+
+ CREATE FUNCTION check_content() RETURNS TRIGGER AS $$
+ BEGIN
+ IF LENGTH(NEW.id) = 0 THEN
+ RAISE EXCEPTION 'content is empty';
+ END IF;
+ END;
+ $$ language 'plpgsql';
+
+ CREATE TRIGGER some_check_trigger
+ BEFORE UPDATE ON bar
+ FOR EACH ROW
+ EXECUTE FUNCTION check_content();
+ `,
+ },
+ },
+ },
+ {
+ name: "Drop partitioned table, Add partitioned table with local keys",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo VARCHAR(255) DEFAULT 'some default' NOT NULL CHECK (LENGTH(foo) > 0),
+ bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ fizz BOOLEAN NOT NULL,
+ PRIMARY KEY (foo, id)
+ ) PARTITION BY LIST(foo);
+
+ CREATE TABLE foobar_1 PARTITION of foobar(
+ fizz NOT NULL
+ ) FOR VALUES IN ('foobar_1_val_1', 'foobar_1_val_2');
+
+ -- partitioned indexes
+ CREATE INDEX foobar_normal_idx ON foobar(foo, fizz);
+ CREATE UNIQUE INDEX foobar_unique_idx ON foobar(foo, bar);
+ -- local indexes
+ CREATE INDEX foobar_1_local_idx ON foobar_1(foo, fizz);
+
+ CREATE table bar(
+ id VARCHAR(255) PRIMARY KEY,
+ foo INT DEFAULT 0 CHECK (foo > 0 AND foo > bar),
+ bar DOUBLE PRECISION NOT NULL DEFAULT 8.8,
+ fizz timestamptz DEFAULT CURRENT_TIMESTAMP,
+ buzz REAL NOT NULL CHECK (buzz IS NOT NULL)
+ );
+ CREATE INDEX bar_normal_idx ON bar(bar);
+ CREATE INDEX bar_another_normal_id ON bar(bar, fizz);
+ CREATE UNIQUE INDEX bar_unique_idx on bar(fizz, buzz);
+
+ CREATE TABLE fizz(
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE new_foobar(
+ id INT,
+ foo VARCHAR(255) DEFAULT 'some default' NOT NULL CHECK (LENGTH(foo) > 0),
+ bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ fizz BOOLEAN NOT NULL
+ ) PARTITION BY LIST(foo);
+
+ CREATE TABLE foobar_1 PARTITION of new_foobar(
+ fizz NOT NULL,
+ PRIMARY KEY (foo, bar)
+ ) FOR VALUES IN ('foobar_1_val_1', 'foobar_1_val_2');
+
+ -- local indexes
+ CREATE INDEX foobar_1_local_idx ON foobar_1(foo, fizz);
+ -- partitioned indexes
+ CREATE INDEX foobar_normal_idx ON new_foobar(foo, fizz);
+ CREATE UNIQUE INDEX foobar_unique_idx ON new_foobar(foo, bar);
+
+ CREATE table bar(
+ id VARCHAR(255) PRIMARY KEY,
+ foo INT DEFAULT 0 CHECK (foo > 0 AND foo > bar),
+ bar DOUBLE PRECISION NOT NULL DEFAULT 8.8,
+ fizz timestamptz DEFAULT CURRENT_TIMESTAMP,
+ buzz REAL NOT NULL CHECK (buzz IS NOT NULL)
+ );
+ CREATE INDEX bar_normal_idx ON bar(bar);
+ CREATE INDEX bar_another_normal_id ON bar(bar, fizz);
+ CREATE UNIQUE INDEX bar_unique_idx on bar(fizz, buzz);
+
+ CREATE TABLE fizz(
+ );
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeDeletesData,
+ },
+ dataPackingExpectations: expectations{
+ outputState: []string{
+ `
+ CREATE TABLE new_foobar(
+ bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ id INT,
+ fizz BOOLEAN NOT NULL,
+ foo VARCHAR(255) DEFAULT 'some default' NOT NULL CHECK (LENGTH(foo) > 0)
+ ) PARTITION BY LIST(foo);
+
+ CREATE TABLE foobar_1 PARTITION of new_foobar(
+ fizz NOT NULL,
+ PRIMARY KEY (foo, bar)
+ ) FOR VALUES IN ('foobar_1_val_1', 'foobar_1_val_2');
+
+ -- local indexes
+ CREATE INDEX foobar_1_local_idx ON foobar_1(foo, fizz);
+ -- partitioned indexes
+ CREATE INDEX foobar_normal_idx ON new_foobar(foo, fizz);
+ CREATE UNIQUE INDEX foobar_unique_idx ON new_foobar(foo, bar);
+
+ CREATE table bar(
+ id VARCHAR(255) PRIMARY KEY,
+ foo INT DEFAULT 0 CHECK (foo > 0 AND foo > bar),
+ bar DOUBLE PRECISION NOT NULL DEFAULT 8.8,
+ fizz timestamptz DEFAULT CURRENT_TIMESTAMP,
+ buzz REAL NOT NULL CHECK (buzz IS NOT NULL)
+ );
+ CREATE INDEX bar_normal_idx ON bar(bar);
+ CREATE INDEX bar_another_normal_id ON bar(bar, fizz);
+ CREATE UNIQUE INDEX bar_unique_idx on bar(fizz, buzz);
+
+ CREATE TABLE fizz(
+ );
+ `,
+ },
+ },
+ },
+}
+
+func (suite *acceptanceTestSuite) TestSchemaAcceptanceTestCases() {
+ suite.runTestCases(schemaAcceptanceTests)
+}
diff --git a/internal/migration_acceptance_tests/table_cases_test.go b/internal/migration_acceptance_tests/table_cases_test.go
new file mode 100644
index 0000000..0c5c9a1
--- /dev/null
+++ b/internal/migration_acceptance_tests/table_cases_test.go
@@ -0,0 +1,274 @@
+package migration_acceptance_tests
+
+import (
+ "github.com/stripe/pg-schema-diff/pkg/diff"
+)
+
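+// These are tests for "plain" (non-partitioned) table alterations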
+var tableAcceptanceTestCases = []acceptanceTestCase{
+ {
+ name: "No-op",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY CHECK (id > 0), CHECK (id < buzz),
+ foo VARCHAR(255) COLLATE "POSIX" DEFAULT '' NOT NULL,
+ bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ fizz BOOLEAN NOT NULL,
+ buzz REAL CHECK (buzz IS NOT NULL)
+ );
+ CREATE INDEX normal_idx ON foobar(fizz);
+ CREATE UNIQUE INDEX unique_idx ON foobar(foo, bar);
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY CHECK (id > 0),
+ foo VARCHAR(255) COLLATE "POSIX" DEFAULT '' NOT NULL,
+ bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ fizz BOOLEAN NOT NULL,
+ buzz REAL CHECK (buzz IS NOT NULL)
+ );
+ CREATE INDEX normal_idx ON foobar(fizz);
+ CREATE UNIQUE INDEX unique_idx ON foobar(foo, bar);
+ `,
+ },
+ },
+ {
+ name: "Create table",
+ oldSchemaDDL: nil,
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY CHECK (id > 0), CHECK (id < buzz),
+ foo VARCHAR(255) COLLATE "POSIX" DEFAULT '' NOT NULL,
+ bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ fizz BOOLEAN NOT NULL,
+ buzz REAL CHECK (buzz IS NOT NULL)
+ );
+ CREATE INDEX normal_idx ON foobar(fizz);
+ CREATE UNIQUE INDEX unique_idx ON foobar(foo, bar);
+ `,
+ },
+ dataPackingExpectations: expectations{
+ outputState: []string{`
+ CREATE TABLE foobar(
+ bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ id INT PRIMARY KEY CHECK (id > 0), CHECK (id < buzz),
+ buzz REAL CHECK (buzz IS NOT NULL),
+ fizz BOOLEAN NOT NULL,
+ foo VARCHAR(255) COLLATE "POSIX" DEFAULT '' NOT NULL
+ );
+ CREATE INDEX normal_idx ON foobar(fizz);
+ CREATE UNIQUE INDEX unique_idx ON foobar(foo, bar);
+ `},
+ },
+ },
+ {
+ name: "Create table with quoted names",
+ oldSchemaDDL: nil,
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE "Foobar"(
+ id INT PRIMARY KEY,
+ "Foo" VARCHAR(255) COLLATE "POSIX" DEFAULT '' NOT NULL CHECK (LENGTH("Foo") > 0),
+ bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ fizz BOOLEAN NOT NULL
+ );
+ CREATE INDEX normal_idx ON "Foobar" USING hash (fizz);
+ CREATE UNIQUE INDEX unique_idx ON "Foobar"("Foo" DESC, bar);
+ `,
+ },
+ dataPackingExpectations: expectations{
+ outputState: []string{
+ `
+ CREATE TABLE "Foobar"(
+ bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ id INT PRIMARY KEY,
+ fizz BOOLEAN NOT NULL,
+ "Foo" VARCHAR(255) COLLATE "POSIX" DEFAULT '' NOT NULL CHECK (LENGTH("Foo") > 0)
+ );
+ CREATE INDEX normal_idx ON "Foobar" USING hash (fizz);
+ CREATE UNIQUE INDEX unique_idx ON "Foobar"("Foo" DESC, bar);
+ `,
+ },
+ },
+ },
+ {
+ name: "Drop table",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY CHECK (id > 0), CHECK (id < buzz),
+ foo VARCHAR(255) COLLATE "C" DEFAULT '' NOT NULL,
+ bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ fizz BOOLEAN NOT NULL,
+ buzz REAL CHECK (buzz IS NOT NULL)
+ );
+ CREATE INDEX normal_idx ON foobar(fizz);
+ CREATE UNIQUE INDEX unique_idx ON foobar(foo, bar);
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeDeletesData,
+ },
+ newSchemaDDL: nil,
+ },
+ {
+ name: "Drop a table with quoted names",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE "Foobar"(
+ id INT PRIMARY KEY,
+ "Foo" VARCHAR(255) COLLATE "C" DEFAULT '' NOT NULL CHECK (LENGTH("Foo") > 0),
+ bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ fizz BOOLEAN NOT NULL
+ );
+ CREATE INDEX normal_idx ON "Foobar"(fizz);
+ CREATE UNIQUE INDEX unique_idx ON "Foobar"("Foo", "bar");
+ `,
+ },
+ newSchemaDDL: nil,
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeDeletesData,
+ },
+ },
+ {
+ name: "Alter table: New primary key, change column types, delete unique index, new index, validate check constraint",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY CHECK (id > 0), CHECK (id < buzz),
+ foo VARCHAR(255) COLLATE "POSIX" DEFAULT '' NOT NULL,
+ bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ fizz BOOLEAN NOT NULL,
+ buzz REAL,
+ fizzbuzz TEXT
+ );
+ ALTER TABLE foobar ADD CONSTRAINT buzz_check CHECK (buzz IS NOT NULL) NOT VALID;
+ CREATE INDEX normal_idx ON foobar(fizz);
+ CREATE UNIQUE INDEX unique_idx ON foobar(foo, bar);
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT CHECK (id > 0),
+ foo CHAR COLLATE "C" DEFAULT '5' NOT NULL PRIMARY KEY,
+ bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
+ fizz BOOLEAN NOT NULL,
+ buzz REAL,
+ fizzbuzz TEXT COLLATE "POSIX"
+ );
+ ALTER TABLE foobar ADD CONSTRAINT buzz_check CHECK (buzz IS NOT NULL);
+ CREATE INDEX normal_idx ON foobar(fizz);
+ CREATE INDEX other_idx ON foobar(bar);
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeIndexDropped,
+ diff.MigrationHazardTypeIndexBuild,
+ diff.MigrationHazardTypeImpactsDatabasePerformance,
+ },
+ },
+ {
+ name: "Alter table: New column, new primary key, alter column to nullable, drop column, drop index, drop check constraints",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY CHECK (id > 0), CHECK (id < buzz),
+ foo VARCHAR(255) COLLATE "POSIX" DEFAULT '' NOT NULL,
+ bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ fizz BOOLEAN NOT NULL,
+ buzz REAL CHECK (buzz IS NOT NULL)
+ );
+ CREATE INDEX normal_idx ON foobar(fizz);
+ CREATE UNIQUE INDEX unique_idx ON foobar(foo DESC, bar);
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT,
+ foo CHAR DEFAULT '5',
+ bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
+ new_fizz DECIMAL(65, 10) DEFAULT 5.25 NOT NULL PRIMARY KEY
+ );
+ CREATE INDEX other_idx ON foobar(bar);
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeImpactsDatabasePerformance,
+ diff.MigrationHazardTypeDeletesData,
+ diff.MigrationHazardTypeIndexDropped,
+ diff.MigrationHazardTypeIndexBuild,
+ },
+ },
+ {
+ name: "Alter table: effectively drop by changing everything",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY CHECK (id > 0), CHECK (id < buzz),
+ foo VARCHAR(255) COLLATE "POSIX" DEFAULT '' NOT NULL,
+ bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ fizz BOOLEAN NOT NULL,
+ buzz REAL CHECK (buzz IS NOT NULL)
+ );
+ CREATE INDEX normal_idx ON foobar USING hash (fizz);
+ CREATE UNIQUE INDEX unique_idx ON foobar(foo DESC, bar);
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foobar(
+ new_id INT PRIMARY KEY CHECK (new_id > 0), CHECK (new_id < new_buzz),
+ new_foo VARCHAR(255) COLLATE "POSIX" DEFAULT '' NOT NULL,
+ new_bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ new_fizz BOOLEAN NOT NULL,
+ new_buzz REAL CHECK (new_buzz IS NOT NULL)
+ );
+ CREATE INDEX normal_idx ON foobar USING hash (new_fizz);
+ CREATE UNIQUE INDEX unique_idx ON foobar(new_foo DESC, new_bar);
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeDeletesData,
+ diff.MigrationHazardTypeIndexDropped,
+ diff.MigrationHazardTypeIndexBuild,
+ },
+ },
+ {
+ name: "Alter table: translate BIGINT type to TIMESTAMP, set to not null, set default",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE alexrhee_testing(
+ id INT PRIMARY KEY,
+ obj_attr__c_time BIGINT,
+ obj_attr__m_time BIGINT
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE alexrhee_testing(
+ id INT PRIMARY KEY,
+ obj_attr__c_time TIMESTAMP NOT NULL,
+ obj_attr__m_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
+ );
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeAcquiresAccessExclusiveLock,
+ diff.MigrationHazardTypeImpactsDatabasePerformance,
+ },
+ },
+}
+
+func (suite *acceptanceTestSuite) TestTableAcceptanceTestCases() {
+ suite.runTestCases(tableAcceptanceTestCases)
+}
diff --git a/internal/migration_acceptance_tests/trigger_cases_test.go b/internal/migration_acceptance_tests/trigger_cases_test.go
new file mode 100644
index 0000000..fd41da4
--- /dev/null
+++ b/internal/migration_acceptance_tests/trigger_cases_test.go
@@ -0,0 +1,703 @@
+package migration_acceptance_tests
+
+import (
+ "github.com/stripe/pg-schema-diff/pkg/diff"
+)
+
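+// These are tests for trigger alterations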
+var triggerAcceptanceTestCases = []acceptanceTestCase{
+ {
+ name: "No-op",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE foo (
+ id INTEGER PRIMARY KEY,
+ author TEXT,
+ content TEXT NOT NULL DEFAULT '',
+ version INT NOT NULL DEFAULT 0
+ );
+
+ CREATE FUNCTION increment_version() RETURNS TRIGGER AS $$
+ BEGIN
+ NEW.version = OLD.version + 1;
+ RETURN NEW;
+ END;
+ $$ language 'plpgsql';
+
+ CREATE TRIGGER some_update_trigger
+ BEFORE UPDATE ON foo
+ FOR EACH ROW
+ WHEN (OLD.* IS DISTINCT FROM NEW.*)
+ EXECUTE PROCEDURE increment_version();
+
+ CREATE FUNCTION check_content() RETURNS TRIGGER AS $$
+ BEGIN
+ IF LENGTH(NEW.content) = 0 THEN
+ RAISE EXCEPTION 'content is empty';
+ END IF;
+ END;
+ $$ language 'plpgsql';
+
+ CREATE TRIGGER some_check_trigger
+ BEFORE UPDATE ON foo
+ FOR EACH ROW
+ EXECUTE FUNCTION check_content();
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE foo (
+ id INTEGER PRIMARY KEY,
+ author TEXT,
+ content TEXT NOT NULL DEFAULT '',
+ version INT NOT NULL DEFAULT 0
+ );
+
+ CREATE FUNCTION increment_version() RETURNS TRIGGER AS $$
+ BEGIN
+ NEW.version = OLD.version + 1;
+ RETURN NEW;
+ END;
+ $$ language 'plpgsql';
+
+ CREATE TRIGGER some_update_trigger
+ BEFORE UPDATE ON foo
+ FOR EACH ROW
+ WHEN (OLD.* IS DISTINCT FROM NEW.*)
+ EXECUTE PROCEDURE increment_version();
+
+ CREATE FUNCTION check_content() RETURNS TRIGGER AS $$
+ BEGIN
+ IF LENGTH(NEW.content) = 0 THEN
+ RAISE EXCEPTION 'content is empty';
+ END IF;
+ END;
+ $$ language 'plpgsql';
+
+ CREATE TRIGGER some_check_trigger
+ BEFORE UPDATE ON foo
+ FOR EACH ROW
+ EXECUTE FUNCTION check_content();
+ `,
+ },
+ },
+ {
+ name: "Create trigger with quoted name",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE "some foo" (
+ id INTEGER PRIMARY KEY,
+ author TEXT,
+ content TEXT NOT NULL DEFAULT '',
+ version INT NOT NULL DEFAULT 0
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE "some foo" (
+ id INTEGER PRIMARY KEY,
+ author TEXT,
+ content TEXT NOT NULL DEFAULT '',
+ version INT NOT NULL DEFAULT 0
+ );
+
+ CREATE FUNCTION "increment version"() RETURNS TRIGGER AS $$
+ BEGIN
+ NEW.version = OLD.version + 1;
+ RETURN NEW;
+ END;
+ $$ language 'plpgsql';
+
+ CREATE TRIGGER "some trigger"
+ BEFORE UPDATE ON "some foo"
+ FOR EACH ROW
+ WHEN (OLD.* IS DISTINCT FROM NEW.*)
+ EXECUTE PROCEDURE "increment version"();
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{diff.MigrationHazardTypeHasUntrackableDependencies},
+ },
+ {
+ name: "Create triggers with quoted names on partitioned table and partition",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE "some foo" (
+ id INTEGER,
+ author TEXT,
+ content TEXT NOT NULL DEFAULT '',
+ version INT NOT NULL DEFAULT 0
+ ) PARTITION BY LIST (content);
+
+ CREATE TABLE "foobar 1" PARTITION OF "some foo" FOR VALUES IN ('foo_2');
+
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE "some foo" (
+ id INTEGER,
+ author TEXT,
+ content TEXT NOT NULL DEFAULT '',
+ version INT NOT NULL DEFAULT 0
+ ) PARTITION BY LIST (content);
+
+ CREATE TABLE "foobar 1" PARTITION OF "some foo" FOR VALUES IN ('foo_2');
+
+ CREATE FUNCTION "increment version"() RETURNS TRIGGER AS $$
+ BEGIN
+ NEW.version = OLD.version + 1;
+ RETURN NEW;
+ END;
+ $$ language 'plpgsql';
+
+ CREATE TRIGGER "some trigger"
+ BEFORE UPDATE ON "some foo"
+ FOR EACH ROW
+ WHEN (OLD.* IS DISTINCT FROM NEW.*)
+ EXECUTE PROCEDURE "increment version"();
+
+ CREATE TRIGGER "some partition trigger"
+ BEFORE UPDATE ON "foobar 1"
+ FOR EACH ROW
+ WHEN (OLD.* IS DISTINCT FROM NEW.*)
+ EXECUTE PROCEDURE "increment version"();
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{diff.MigrationHazardTypeHasUntrackableDependencies},
+ },
+ {
+ name: "Create two triggers depending on the same function",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE "some foo" (
+ id INTEGER PRIMARY KEY,
+ author TEXT,
+ content TEXT NOT NULL DEFAULT '',
+ version INT NOT NULL DEFAULT 0
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE "some foo" (
+ id INTEGER PRIMARY KEY,
+ author TEXT,
+ content TEXT NOT NULL DEFAULT '',
+ version INT NOT NULL DEFAULT 0
+ );
+
+ CREATE FUNCTION "increment version"() RETURNS TRIGGER AS $$
+ BEGIN
+ NEW.version = OLD.version + 1;
+ RETURN NEW;
+ END;
+ $$ language 'plpgsql';
+
+ CREATE TRIGGER "some trigger"
+ BEFORE UPDATE ON "some foo"
+ FOR EACH ROW
+ WHEN (OLD.* IS DISTINCT FROM NEW.*)
+ EXECUTE PROCEDURE "increment version"();
+
+ CREATE TRIGGER "some other trigger"
+ BEFORE UPDATE ON "some foo"
+ FOR EACH ROW
+ WHEN (OLD.* IS DISTINCT FROM NEW.*)
+ EXECUTE PROCEDURE "increment version"();
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{diff.MigrationHazardTypeHasUntrackableDependencies},
+ },
+ {
+ name: "Create two triggers with the same name",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE "some foo" (
+ id INTEGER PRIMARY KEY,
+ author TEXT,
+ content TEXT NOT NULL DEFAULT '',
+ version INT NOT NULL DEFAULT 0
+ );
+
+ CREATE TABLE "some other foo" (
+ id INTEGER PRIMARY KEY,
+ author TEXT,
+ content TEXT NOT NULL DEFAULT '',
+ version INT NOT NULL DEFAULT 0
+ );
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE "some foo" (
+ id INTEGER PRIMARY KEY,
+ author TEXT,
+ content TEXT NOT NULL DEFAULT '',
+ version INT NOT NULL DEFAULT 0
+ );
+
+ CREATE TABLE "some other foo" (
+ id INTEGER PRIMARY KEY,
+ author TEXT,
+ content TEXT NOT NULL DEFAULT '',
+ version INT NOT NULL DEFAULT 0
+ );
+
+ CREATE FUNCTION "increment version"() RETURNS TRIGGER AS $$
+ BEGIN
+ NEW.version = OLD.version + 1;
+ RETURN NEW;
+ END;
+ $$ language 'plpgsql';
+
+ CREATE TRIGGER "some trigger"
+ BEFORE UPDATE ON "some foo"
+ FOR EACH ROW
+ WHEN (OLD.* IS DISTINCT FROM NEW.*)
+ EXECUTE PROCEDURE "increment version"();
+
+ CREATE TRIGGER "some trigger"
+ BEFORE UPDATE ON "some other foo"
+ FOR EACH ROW
+ WHEN (OLD.* IS DISTINCT FROM NEW.*)
+ EXECUTE PROCEDURE "increment version"();
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{diff.MigrationHazardTypeHasUntrackableDependencies},
+ },
+ {
+ name: "Drop trigger with quoted names",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE "some foo" (
+ id INTEGER PRIMARY KEY,
+ author TEXT,
+ content TEXT NOT NULL DEFAULT '',
+ version INT NOT NULL DEFAULT 0
+ );
+
+ CREATE FUNCTION "increment version"() RETURNS TRIGGER AS $$
+ BEGIN
+ NEW.version = OLD.version + 1;
+ RETURN NEW;
+ END;
+ $$ language 'plpgsql';
+
+ CREATE TRIGGER "some trigger"
+ BEFORE UPDATE ON "some foo"
+ FOR EACH ROW
+ WHEN (OLD.* IS DISTINCT FROM NEW.*)
+ EXECUTE PROCEDURE "increment version"();
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE "some foo" (
+ id INTEGER PRIMARY KEY,
+ author TEXT,
+ content TEXT NOT NULL DEFAULT '',
+ version INT NOT NULL DEFAULT 0
+ );
+
+ CREATE FUNCTION "increment version"() RETURNS TRIGGER AS $$
+ BEGIN
+ NEW.version = OLD.version + 1;
+ RETURN NEW;
+ END;
+ $$ language 'plpgsql';
+ `,
+ },
+ },
+ {
+ name: "Drop triggers with quoted names on partitioned table and partition",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE "some foo" (
+ id INTEGER,
+ author TEXT,
+ content TEXT NOT NULL DEFAULT '',
+ version INT NOT NULL DEFAULT 0
+ ) PARTITION BY LIST (content);
+
+ CREATE TABLE "foobar 1" PARTITION OF "some foo" FOR VALUES IN ('foo_2');
+
+ CREATE FUNCTION "increment version"() RETURNS TRIGGER AS $$
+ BEGIN
+ NEW.version = OLD.version + 1;
+ RETURN NEW;
+ END;
+ $$ language 'plpgsql';
+
+ CREATE TRIGGER "some trigger"
+ BEFORE UPDATE ON "some foo"
+ FOR EACH ROW
+ WHEN (OLD.* IS DISTINCT FROM NEW.*)
+ EXECUTE PROCEDURE "increment version"();
+
+ CREATE TRIGGER "some partition trigger"
+ BEFORE UPDATE ON "foobar 1"
+ FOR EACH ROW
+ WHEN (OLD.* IS DISTINCT FROM NEW.*)
+ EXECUTE PROCEDURE "increment version"();
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE "some foo" (
+ id INTEGER,
+ author TEXT,
+ content TEXT NOT NULL DEFAULT '',
+ version INT NOT NULL DEFAULT 0
+ ) PARTITION BY LIST (content);
+
+ CREATE TABLE "foobar 1" PARTITION OF "some foo" FOR VALUES IN ('foo_2');
+
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{diff.MigrationHazardTypeHasUntrackableDependencies},
+ },
+ {
+ name: "Drop two triggers with the same name",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE "some foo" (
+ id INTEGER PRIMARY KEY,
+ author TEXT,
+ content TEXT NOT NULL DEFAULT '',
+ version INT NOT NULL DEFAULT 0
+ );
+
+ CREATE TABLE "some other foo" (
+ id INTEGER PRIMARY KEY,
+ author TEXT,
+ content TEXT NOT NULL DEFAULT '',
+ version INT NOT NULL DEFAULT 0
+ );
+
+ CREATE FUNCTION "increment version"() RETURNS TRIGGER AS $$
+ BEGIN
+ NEW.version = OLD.version + 1;
+ RETURN NEW;
+ END;
+ $$ language 'plpgsql';
+
+ CREATE TRIGGER "some trigger"
+ BEFORE UPDATE ON "some foo"
+ FOR EACH ROW
+ WHEN (OLD.* IS DISTINCT FROM NEW.*)
+ EXECUTE PROCEDURE "increment version"();
+
+ CREATE TRIGGER "some trigger"
+ BEFORE UPDATE ON "some other foo"
+ FOR EACH ROW
+ WHEN (OLD.* IS DISTINCT FROM NEW.*)
+ EXECUTE PROCEDURE "increment version"();
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE "some foo" (
+ id INTEGER PRIMARY KEY,
+ author TEXT,
+ content TEXT NOT NULL DEFAULT '',
+ version INT NOT NULL DEFAULT 0
+ );
+
+ CREATE TABLE "some other foo" (
+ id INTEGER PRIMARY KEY,
+ author TEXT,
+ content TEXT NOT NULL DEFAULT '',
+ version INT NOT NULL DEFAULT 0
+ );
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{diff.MigrationHazardTypeHasUntrackableDependencies},
+ },
+ {
+ name: "Alter trigger when clause",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE "some foo" (
+ id INTEGER PRIMARY KEY,
+ author TEXT,
+ content TEXT NOT NULL DEFAULT '',
+ version INT NOT NULL DEFAULT 0
+ );
+
+ CREATE FUNCTION "increment version"() RETURNS TRIGGER AS $$
+ BEGIN
+ NEW.version = OLD.version + 1;
+ RETURN NEW;
+ END;
+ $$ language 'plpgsql';
+
+ CREATE TRIGGER "some trigger"
+ BEFORE UPDATE ON "some foo"
+ FOR EACH ROW
+ WHEN (OLD.* IS DISTINCT FROM NEW.*)
+ EXECUTE PROCEDURE "increment version"();
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE "some foo" (
+ id INTEGER PRIMARY KEY,
+ author TEXT,
+ content TEXT NOT NULL DEFAULT '',
+ version INT NOT NULL DEFAULT 0
+ );
+
+ CREATE FUNCTION "increment version"() RETURNS TRIGGER AS $$
+ BEGIN
+ NEW.version = OLD.version + 1;
+ RETURN NEW;
+ END;
+ $$ language 'plpgsql';
+
+ CREATE TRIGGER "some trigger"
+ BEFORE UPDATE ON "some foo"
+ FOR EACH ROW
+ WHEN (NEW.author != 'fizz')
+ EXECUTE PROCEDURE "increment version"();
+ `,
+ },
+ },
+ {
+ name: "Alter trigger table",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE "some foo" (
+ id INTEGER PRIMARY KEY,
+ author TEXT,
+ content TEXT NOT NULL DEFAULT '',
+ version INT NOT NULL DEFAULT 0
+ );
+ CREATE TABLE "some other foo" (
+ id INTEGER PRIMARY KEY,
+ author TEXT,
+ content TEXT NOT NULL DEFAULT '',
+ version INT NOT NULL DEFAULT 0
+ );
+
+ CREATE FUNCTION "increment version"() RETURNS TRIGGER AS $$
+ BEGIN
+ NEW.version = OLD.version + 1;
+ RETURN NEW;
+ END;
+ $$ language 'plpgsql';
+
+ CREATE TRIGGER "some trigger"
+ BEFORE UPDATE ON "some foo"
+ FOR EACH ROW
+ WHEN (OLD.* IS DISTINCT FROM NEW.*)
+ EXECUTE PROCEDURE "increment version"();
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE "some foo" (
+ id INTEGER PRIMARY KEY,
+ author TEXT,
+ content TEXT NOT NULL DEFAULT '',
+ version INT NOT NULL DEFAULT 0
+ );
+ CREATE TABLE "some other foo" (
+ id INTEGER PRIMARY KEY,
+ author TEXT,
+ content TEXT NOT NULL DEFAULT '',
+ version INT NOT NULL DEFAULT 0
+ );
+
+ CREATE FUNCTION "increment version"() RETURNS TRIGGER AS $$
+ BEGIN
+ NEW.version = OLD.version + 1;
+ RETURN NEW;
+ END;
+ $$ language 'plpgsql';
+
+ CREATE TRIGGER "some trigger"
+ BEFORE UPDATE ON "some other foo"
+ FOR EACH ROW
+ WHEN (OLD.* IS DISTINCT FROM NEW.*)
+ EXECUTE PROCEDURE "increment version"();
+ `,
+ },
+ },
+ {
+ name: "Change trigger function and keep old function",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE "some foo" (
+ id INTEGER PRIMARY KEY,
+ author TEXT,
+ content TEXT NOT NULL DEFAULT '',
+ version INT NOT NULL DEFAULT 0
+ );
+
+ CREATE FUNCTION "increment version"() RETURNS TRIGGER AS $$
+ BEGIN
+ NEW.version = OLD.version + 1;
+ RETURN NEW;
+ END;
+ $$ language 'plpgsql';
+
+ CREATE TRIGGER "some trigger"
+ BEFORE UPDATE ON "some foo"
+ FOR EACH ROW
+ WHEN (OLD.* IS DISTINCT FROM NEW.*)
+ EXECUTE PROCEDURE "increment version"();
+ `},
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE "some foo" (
+ id INTEGER PRIMARY KEY,
+ author TEXT,
+ content TEXT NOT NULL DEFAULT '',
+ version INT NOT NULL DEFAULT 0
+ );
+
+ CREATE FUNCTION "increment version"() RETURNS TRIGGER AS $$
+ BEGIN
+ NEW.version = OLD.version + 1;
+ RETURN NEW;
+ END;
+ $$ language 'plpgsql';
+
+
+ CREATE FUNCTION "decrement version"() RETURNS TRIGGER AS $$
+ BEGIN
+ NEW.version = OLD.version + 1;
+ RETURN NEW;
+ END;
+ $$ language 'plpgsql';
+
+ CREATE TRIGGER "some trigger"
+ BEFORE UPDATE ON "some foo"
+ FOR EACH ROW
+ WHEN (OLD.* IS DISTINCT FROM NEW.*)
+ EXECUTE PROCEDURE "decrement version"();
+ `},
+ expectedHazardTypes: []diff.MigrationHazardType{diff.MigrationHazardTypeHasUntrackableDependencies},
+ },
+ {
+ name: "Change trigger function and drop old function",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE "some foo" (
+ id INTEGER PRIMARY KEY,
+ author TEXT,
+ content TEXT NOT NULL DEFAULT '',
+ version INT NOT NULL DEFAULT 0
+ );
+
+ CREATE FUNCTION "increment version"() RETURNS TRIGGER AS $$
+ BEGIN
+ NEW.version = OLD.version + 1;
+ RETURN NEW;
+ END;
+ $$ language 'plpgsql';
+
+ CREATE TRIGGER "some trigger"
+ BEFORE UPDATE ON "some foo"
+ FOR EACH ROW
+ WHEN (OLD.* IS DISTINCT FROM NEW.*)
+ EXECUTE PROCEDURE "increment version"();
+ `},
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE "some foo" (
+ id INTEGER PRIMARY KEY,
+ author TEXT,
+ content TEXT NOT NULL DEFAULT '',
+ version INT NOT NULL DEFAULT 0
+ );
+
+ CREATE FUNCTION "decrement version"() RETURNS TRIGGER AS $$
+ BEGIN
+ NEW.version = OLD.version + 1;
+ RETURN NEW;
+ END;
+ $$ language 'plpgsql';
+
+ CREATE TRIGGER "some trigger"
+ BEFORE UPDATE ON "some foo"
+ FOR EACH ROW
+ WHEN (OLD.* IS DISTINCT FROM NEW.*)
+ EXECUTE PROCEDURE "decrement version"();
+ `},
+ expectedHazardTypes: []diff.MigrationHazardType{diff.MigrationHazardTypeHasUntrackableDependencies},
+ },
+ {
+ name: "Trigger on re-created table is re-created",
+ oldSchemaDDL: []string{
+ `
+ CREATE TABLE "some foobar" (
+ id INTEGER,
+ version INT NOT NULL DEFAULT 0,
+ author TEXT,
+ content TEXT NOT NULL DEFAULT ''
+ ) PARTITION BY LIST (content);
+
+ CREATE TABLE "foobar 1" PARTITION OF "some foobar" FOR VALUES IN ('foo_2');
+
+ CREATE FUNCTION "increment version"() RETURNS TRIGGER AS $$
+ BEGIN
+ NEW.version = OLD.version + 1;
+ RETURN NEW;
+ END;
+ $$ language 'plpgsql';
+
+ CREATE TRIGGER "some trigger"
+ BEFORE UPDATE ON "some foobar"
+ FOR EACH ROW
+ WHEN (OLD.* IS DISTINCT FROM NEW.*)
+ EXECUTE PROCEDURE "increment version"();
+
+ CREATE TRIGGER "some partition trigger"
+ BEFORE UPDATE ON "foobar 1"
+ FOR EACH ROW
+ WHEN (OLD.* IS DISTINCT FROM NEW.*)
+ EXECUTE PROCEDURE "increment version"();
+ `,
+ },
+ newSchemaDDL: []string{
+ `
+ CREATE TABLE "some other foobar" (
+ id INTEGER,
+ version INT NOT NULL DEFAULT 0,
+ author TEXT,
+ content TEXT NOT NULL DEFAULT ''
+ ) PARTITION BY LIST (content);
+
+ CREATE TABLE "foobar 1" PARTITION OF "some other foobar" FOR VALUES IN ('foo_2');
+
+ CREATE FUNCTION "increment version"() RETURNS TRIGGER AS $$
+ BEGIN
+ NEW.version = OLD.version + 1;
+ RETURN NEW;
+ END;
+ $$ language 'plpgsql';
+
+ CREATE TRIGGER "some trigger"
+ BEFORE UPDATE ON "some other foobar"
+ FOR EACH ROW
+ WHEN (OLD.* IS DISTINCT FROM NEW.*)
+ EXECUTE PROCEDURE "increment version"();
+
+ CREATE TRIGGER "some partition trigger"
+ BEFORE UPDATE ON "foobar 1"
+ FOR EACH ROW
+ WHEN (OLD.* IS DISTINCT FROM NEW.*)
+ EXECUTE PROCEDURE "increment version"();
+ `,
+ },
+ expectedHazardTypes: []diff.MigrationHazardType{
+ diff.MigrationHazardTypeDeletesData,
+ },
+ },
+}
+
+func (suite *acceptanceTestSuite) TestTriggerAcceptanceTestCases() {
+ suite.runTestCases(triggerAcceptanceTestCases)
+}
diff --git a/internal/pgdump/dump.go b/internal/pgdump/dump.go
new file mode 100644
index 0000000..d63dac8
--- /dev/null
+++ b/internal/pgdump/dump.go
@@ -0,0 +1,61 @@
+package pgdump
+
+import (
+ "errors"
+ "fmt"
+ "os/exec"
+
+ "github.com/stripe/pg-schema-diff/internal/pgengine"
+)
+
+// Parameter represents a parameter passed to pg_dump. Don't use a type alias for a string slice
+// because all parameters for pg_dump should be explicitly added here
+type Parameter struct {
+ values []string `explicit:"always"`
+}
+
+func WithExcludeSchema(pattern string) Parameter {
+ return Parameter{
+ values: []string{"--exclude-schema", pattern},
+ }
+}
+
+func WithSchemaOnly() Parameter {
+ return Parameter{
+ values: []string{"--schema-only"},
+ }
+}
+
+// GetDump gets the pg_dump of the inputted database.
+// It is only intended to be used for testing. You cannot securely pass passwords with this implementation, so it will
+// only accept databases created for unit tests (spun up with the pgengine package)
+// "pgdump" must be on the system's PATH
+func GetDump(db *pgengine.DB, additionalParams ...Parameter) (string, error) {
+ pgDumpBinaryPath, err := exec.LookPath("pg_dump")
+ if err != nil {
+ return "", errors.New("pg_dump executable not found in path")
+ }
+ return GetDumpUsingBinary(pgDumpBinaryPath, db, additionalParams...)
+}
+
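+// GetDumpUsingBinary is like GetDump but invokes the pg_dump binary at the given path.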
+func GetDumpUsingBinary(pgDumpBinaryPath string, db *pgengine.DB, additionalParams ...Parameter) (string, error) {
+ params := []string{
+ db.GetDSN(),
+ }
+ for _, param := range additionalParams {
+ params = append(params, param.values...)
+ }
+
+ output, err := exec.Command(pgDumpBinaryPath, params...).CombinedOutput()
+ if err != nil {
+ return "", fmt.Errorf("running pg_dump: %w\noutput=%s", err, output)
+ }
+
+ return string(output), nil
+}
diff --git a/internal/pgdump/dump_test.go b/internal/pgdump/dump_test.go
new file mode 100644
index 0000000..ff2f130
--- /dev/null
+++ b/internal/pgdump/dump_test.go
@@ -0,0 +1,53 @@
+package pgdump_test
+
+import (
+ "context"
+ "database/sql"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "github.com/stripe/pg-schema-diff/internal/pgdump"
+ "github.com/stripe/pg-schema-diff/internal/pgengine"
+)
+
+func TestGetDump(t *testing.T) {
+ pgEngine, err := pgengine.StartEngine()
+ require.NoError(t, err)
+ defer pgEngine.Close()
+
+ db, err := pgEngine.CreateDatabase()
+ require.NoError(t, err)
+ defer db.DropDB()
+
+ connPool, err := sql.Open("pgx", db.GetDSN())
+ require.NoError(t, err)
+ defer connPool.Close()
+
+ _, err = connPool.ExecContext(context.Background(), `
+ CREATE TABLE foobar(foobar_id text);
+
+ INSERT INTO foobar VALUES ('some-id');
+
+ CREATE SCHEMA test;
+ CREATE TABLE test.bar(bar_id text);
+ `)
+ require.NoError(t, err)
+
+ dump, err := pgdump.GetDump(db)
+ require.NoError(t, err)
+ require.Contains(t, dump, "public.foobar")
+ require.Contains(t, dump, "test.bar")
+ require.Contains(t, dump, "some-id")
+
+ onlySchemasDump, err := pgdump.GetDump(db, pgdump.WithSchemaOnly())
+ require.NoError(t, err)
+ require.Contains(t, onlySchemasDump, "public.foobar")
+ require.Contains(t, onlySchemasDump, "test.bar")
+ require.NotContains(t, onlySchemasDump, "some-id")
+
+ onlyPublicSchemaDump, err := pgdump.GetDump(db, pgdump.WithSchemaOnly(), pgdump.WithExcludeSchema("test"))
+ require.NoError(t, err)
+ require.Contains(t, onlyPublicSchemaDump, "public.foobar")
+ require.NotContains(t, onlyPublicSchemaDump, "test.bar")
+ require.NotContains(t, onlyPublicSchemaDump, "some-id")
+}
diff --git a/internal/pgengine/db.go b/internal/pgengine/db.go
new file mode 100644
index 0000000..2b24105
--- /dev/null
+++ b/internal/pgengine/db.go
@@ -0,0 +1,61 @@
+package pgengine
+
+import (
+ "database/sql"
+ "fmt"
+
+ _ "github.com/jackc/pgx/v4/stdlib"
+)
+
+type DB struct {
+ connOpts ConnectionOptions
+
+ dropped bool
+}
+
+func (d *DB) GetName() string {
+ return d.connOpts[ConnectionOptionDatabase]
+}
+
+func (d *DB) GetConnOpts() ConnectionOptions {
+ return d.connOpts
+}
+
+func (d *DB) GetDSN() string {
+ return d.GetConnOpts().ToDSN()
+}
+
+// DropDB drops the database
+func (d *DB) DropDB() error {
+ if d.dropped {
+ return nil
+ }
+
+ // Connect to the default "postgres" database, since we are dropping the test database
+ db, err := sql.Open("pgx", d.GetConnOpts().With(ConnectionOptionDatabase, "postgres").ToDSN())
+ if err != nil {
+ return err
+ }
+ defer db.Close()
+
+ // Disallow further connections to the test database, except for superusers
+ _, err = db.Exec(fmt.Sprintf("ALTER DATABASE \"%s\" CONNECTION LIMIT 0", d.GetName()))
+ if err != nil {
+ return err
+ }
+
+ // Drop existing connections, so that we can drop the database
+ _, err = db.Exec("SELECT PG_TERMINATE_BACKEND(pid) FROM pg_stat_activity WHERE datname = $1", d.GetName())
+ if err != nil {
+ return err
+ }
+
+ // Finally, drop the database
+ _, err = db.Exec(fmt.Sprintf("DROP DATABASE \"%s\"", d.GetName()))
+ if err != nil {
+ return err
+ }
+
+ d.dropped = true
+ return nil
+}
diff --git a/internal/pgengine/engine.go b/internal/pgengine/engine.go
new file mode 100644
index 0000000..b5bcaa1
--- /dev/null
+++ b/internal/pgengine/engine.go
@@ -0,0 +1,247 @@
+package pgengine
+
+import (
+ "database/sql"
+ "errors"
+ "fmt"
+ "os"
+ "os/exec"
+ "os/user"
+ "path"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/google/uuid"
+ _ "github.com/jackc/pgx/v4/stdlib"
+)
+
+const (
+ port = 5432
+
+ maxConnAttemptsAtStartup = 10
+ waitBetweenStartupConnAttempt = time.Second
+)
+
+type ConnectionOption string
+
+const (
+ ConnectionOptionDatabase ConnectionOption = "dbname"
+)
+
+type ConnectionOptions map[ConnectionOption]string
+
+func (c ConnectionOptions) With(option ConnectionOption, value string) ConnectionOptions {
+ clone := make(ConnectionOptions)
+ for k, v := range c {
+ clone[k] = v
+ }
+ clone[option] = value
+ return clone
+}
+
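+// ToDSN renders the connection options as a space-separated list of key=value pairs, e.g.
+// "dbname=postgres host=/tmp/pgsock port=5432 sslmode=disable". Key order is not guaranteed,
+// since it follows Go's map iteration order.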
+func (c ConnectionOptions) ToDSN() string {
+ var pairs []string
+ for k, v := range c {
+ pairs = append(pairs, fmt.Sprintf("%s=%s", k, v))
+ }
+
+ return strings.Join(pairs, " ")
+}
+
+type Engine struct {
+ user *user.User
+
+ // for cleanup purposes
+ process *os.Process
+ dbPath string
+ sockPath string
+}
+
+// StartEngine starts a postgres instance. This is useful for testing, where Postgres databases need to be spun up.
+// "postgres" must be on the system's PATH, and the binary must be located in a directory containing "initdb"
+func StartEngine() (*Engine, error) {
+ postgresPath, err := exec.LookPath("postgres")
+ if err != nil {
+ return nil, errors.New("postgres executable not found in path")
+ }
+ return StartEngineUsingPgDir(path.Dir(postgresPath))
+}
+
+func StartEngineUsingPgDir(pgDir string) (*Engine, error) {
+ currentUser, err := user.Current()
+ if err != nil {
+ return nil, err
+ }
+
+ dbPath, err := os.MkdirTemp("", "postgresql-")
+ if err != nil {
+ return nil, err
+ }
+
+ sockPath, err := os.MkdirTemp("", "pgsock-")
+ if err != nil {
+ return nil, err
+ }
+
+ if err := initDB(currentUser, path.Join(pgDir, "initdb"), dbPath); err != nil {
+ return nil, err
+ }
+
+ process, err := startServer(path.Join(pgDir, "postgres"), dbPath, sockPath)
+ if err != nil {
+ // Cleanup temporary directories that were created
+ os.RemoveAll(dbPath)
+ os.RemoveAll(sockPath)
+ return nil, err
+ }
+
+ pgEngine := &Engine{
+ dbPath: dbPath,
+ sockPath: sockPath,
+ user: currentUser,
+ process: process,
+ }
+ if err := pgEngine.waitTillServingTraffic(maxConnAttemptsAtStartup, waitBetweenStartupConnAttempt); err != nil {
+ pgEngine.Close()
+ return nil, fmt.Errorf("waiting till server can serve traffic: %w", err)
+ }
+ return pgEngine, nil
+}
+
+func initDB(currentUser *user.User, initDbPath, dbPath string) error {
+ cmd := exec.Command(initDbPath, []string{
+ "-U", currentUser.Username,
+ "-D", dbPath,
+ "-A", "trust",
+ }...)
+
+ output, err := cmd.CombinedOutput()
+ if err != nil {
+ outputStr := string(output)
+
+ var tip string
+ line := strings.Repeat("=", 95)
+ if strings.Contains(outputStr, "request for a shared memory segment exceeded your kernel's SHMALL parameter") {
+ tip = line + "\n Run 'sudo sysctl -w kern.sysv.shmall=16777216' to solve this issue \n" + line + "\n"
+ } else if strings.Contains(outputStr, "could not create shared memory segment: No space left on device") {
+ tip = line + "\n Use the ipcs and ipcrm commands to clear the shared memory \n" + line + "\n"
+ }
+
+ return fmt.Errorf("error running initdb: %w\n%s\n%s", err, outputStr, tip)
+ }
+ return nil
+}
+
+func startServer(pgBinaryPath, dbPath, sockPath string) (*os.Process, error) {
+ cmd := exec.Command(pgBinaryPath, []string{
+ "-D", dbPath,
+ "-k", sockPath,
+ "-p", strconv.Itoa(port),
+ "-h", ""}...)
+
+ cmd.Stdout = os.Stdout
+ cmd.Stderr = os.Stderr
+
+ if err := cmd.Start(); err != nil {
+ return nil, fmt.Errorf("starting postgres server instance: %w", err)
+ }
+
+ return cmd.Process, nil
+}
+
+func (e *Engine) waitTillServingTraffic(maxAttempts int, timeBetweenAttempts time.Duration) error {
+ var mostRecentErr error
+ for i := 0; i < maxAttempts; i++ {
+ mostRecentErr = e.testIfInstanceServingTraffic()
+ if mostRecentErr == nil {
+ return nil
+ }
+ time.Sleep(timeBetweenAttempts)
+ }
+ return fmt.Errorf("unable to establish connection to postgres instance. most recent error: %w", mostRecentErr)
+}
+
+func (e *Engine) testIfInstanceServingTraffic() error {
+ dsn := e.GetPostgresDatabaseConnOpts().ToDSN()
+ db, err := sql.Open("pgx", dsn)
+ if err != nil {
+ return err
+ }
+
+ if err := db.Ping(); err != nil {
+ db.Close()
+ return err
+ }
+ // Don't `defer` the call to Close(), since we want to return its error, if any
+ return db.Close()
+}
+
+func (e *Engine) GetPostgresDatabaseConnOpts() ConnectionOptions {
+ result := make(map[ConnectionOption]string)
+ result[ConnectionOptionDatabase] = "postgres"
+ result["host"] = e.sockPath
+ result["port"] = strconv.Itoa(port)
+ result["sslmode"] = "disable"
+
+ return result
+}
+
+func (e *Engine) GetPostgresDatabaseDSN() string {
+ return e.GetPostgresDatabaseConnOpts().ToDSN()
+}
+
+func (e *Engine) Close() error {
+ // Make best effort attempt to clean up everything
+ e.process.Signal(os.Interrupt)
+ e.process.Wait()
+ os.RemoveAll(e.dbPath)
+ os.RemoveAll(e.sockPath)
+
+ return nil
+}
+
+func (e *Engine) CreateDatabase() (*DB, error) {
+ uuid, err := uuid.NewRandom()
+ if err != nil {
+ return nil, fmt.Errorf("generating uuid: %w", err)
+ }
+ testDBName := fmt.Sprintf("pgtestdb_%v", uuid.String())
+
+ testDb, err := e.CreateDatabaseWithName(testDBName)
+ if err != nil {
+ return nil, err
+ }
+
+ return testDb, nil
+}
+
+func (e *Engine) CreateDatabaseWithName(name string) (*DB, error) {
+ dsn := e.GetPostgresDatabaseConnOpts().With(ConnectionOptionDatabase, "postgres").ToDSN()
+ db, err := sql.Open("pgx", dsn)
+ if err != nil {
+ return nil, err
+ }
+ defer db.Close()
+
+ _, err = db.Exec(fmt.Sprintf("CREATE DATABASE \"%s\"", name))
+ if err != nil {
+ return nil, err
+ }
+
+ return &DB{
+ connOpts: e.GetPostgresDatabaseConnOpts().With(ConnectionOptionDatabase, name),
+ }, nil
+}
diff --git a/internal/pgengine/engine_test.go b/internal/pgengine/engine_test.go
new file mode 100644
index 0000000..a5da12c
--- /dev/null
+++ b/internal/pgengine/engine_test.go
@@ -0,0 +1,105 @@
+package pgengine_test
+
+import (
+ "context"
+ "database/sql"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stripe/pg-schema-diff/internal/pgengine"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestPgEngine(t *testing.T) {
+ engine, err := pgengine.StartEngine()
+ require.NoError(t, err)
+ defer func() {
+ // Close should be idempotent
+ require.NoError(t, engine.Close())
+ require.NoError(t, engine.Close())
+ }()
+
+ unnamedDb, err := engine.CreateDatabase()
+ require.NoError(t, err)
+ assertDatabaseIsValid(t, unnamedDb, "")
+
+ namedDb, err := engine.CreateDatabaseWithName("some-name")
+ require.NoError(t, err)
+ assertDatabaseIsValid(t, namedDb, "some-name")
+
+ // Assert no extra databases were created
+ assert.ElementsMatch(t, getAllDatabaseNames(t, engine), []string{unnamedDb.GetName(), "some-name", "postgres", "template0", "template1"})
+
+ // Hold open a connection before we try to drop the database. The drop should still pass, despite the open
+ // connection
+ connPool, err := sql.Open("pgx", unnamedDb.GetDSN())
+ require.NoError(t, err)
+ defer connPool.Close()
+ conn, err := connPool.Conn(context.Background())
+ require.NoError(t, err)
+ defer conn.Close()
+
+ // Drops should be idempotent
+ require.NoError(t, unnamedDb.DropDB())
+ require.NoError(t, unnamedDb.DropDB())
+ require.NoError(t, namedDb.DropDB())
+ require.NoError(t, namedDb.DropDB())
+
+ // Assert only the expected databases were dropped
+ assert.ElementsMatch(t, getAllDatabaseNames(t, engine), []string{"postgres", "template0", "template1"})
+}
+
+func assertDatabaseIsValid(t *testing.T, db *pgengine.DB, expectedName string) {
+ connPool, err := sql.Open("pgx", db.GetDSN())
+ require.NoError(t, err)
+ defer func() {
+ require.NoError(t, connPool.Close())
+ }()
+
+ var fetchedName string
+ require.NoError(t, connPool.QueryRowContext(context.Background(), "SELECT current_database();").
+ Scan(&fetchedName))
+
+ assert.Equal(t, db.GetName(), fetchedName)
+ if len(expectedName) > 0 {
+ assert.Equal(t, fetchedName, expectedName)
+ }
+
+ // Validate writing and reading works
+ _, err = connPool.ExecContext(context.Background(), `CREATE TABLE foobar(id serial NOT NULL)`)
+ require.NoError(t, err)
+
+ _, err = connPool.ExecContext(context.Background(), `INSERT INTO foobar DEFAULT VALUES`)
+ require.NoError(t, err)
+
+ var id string
+ require.NoError(t, connPool.QueryRowContext(context.Background(), "SELECT * FROM foobar LIMIT 1;").
+ Scan(&id))
+ assert.NotEmpty(t, id)
+}
+
+func getAllDatabaseNames(t *testing.T, engine *pgengine.Engine) []string {
+ conn, err := sql.Open("pgx", engine.GetPostgresDatabaseDSN())
+ require.NoError(t, err)
+ defer func() {
+ require.NoError(t, conn.Close())
+ }()
+
+ rows, err := conn.QueryContext(context.Background(), "SELECT datname FROM pg_database")
+ require.NoError(t, err)
+ defer rows.Close()
+
+ // Iterate over the rows and put the database names in an array
+ var dbNames []string
+ for rows.Next() {
+ var dbName string
+ require.NoError(t, rows.Scan(&dbName))
+ dbNames = append(dbNames, dbName)
+ }
+ require.NoError(t, rows.Err())
+
+ return dbNames
+}
diff --git a/internal/pgidentifier/identifier.go b/internal/pgidentifier/identifier.go
new file mode 100644
index 0000000..d50d47c
--- /dev/null
+++ b/internal/pgidentifier/identifier.go
@@ -0,0 +1,12 @@
+package pgidentifier
+
+import "regexp"
+
+var (
+ // SimpleIdentifierRegex matches identifiers in Postgres that require no quotes
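+ // e.g., it matches "foo", "_foo", and "some_1119$_", but not "1foo", "fooBar", or "foo bar"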
+ SimpleIdentifierRegex = regexp.MustCompile("^[a-z_][a-z0-9_$]*$")
+)
+
+func IsSimpleIdentifier(val string) bool {
+ return SimpleIdentifierRegex.MatchString(val)
+}
diff --git a/internal/pgidentifier/identifier_test.go b/internal/pgidentifier/identifier_test.go
new file mode 100644
index 0000000..4efc384
--- /dev/null
+++ b/internal/pgidentifier/identifier_test.go
@@ -0,0 +1,56 @@
+package pgidentifier
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func Test_IsSimpleIdentifier(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ input string
+ expected bool
+ }{
+ {
+ name: "starts with letter",
+ input: "foo",
+ expected: true,
+ },
+ {
+ name: "starts with underscore",
+ input: "_foo",
+ expected: true,
+ },
+ {
+ name: "start with number",
+ input: "1foo",
+ expected: false,
+ },
+ {
+ name: "contains all possible characters",
+ input: "some_1119$_",
+ expected: true,
+ },
+ {
+ name: "empty",
+ input: "",
+ expected: false,
+ },
+ {
+ name: "contains upper case letter",
+ input: "fooBar",
+ expected: false,
+ },
+ {
+ name: "contains spaces",
+ input: "foo bar",
+ expected: false,
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ actual := IsSimpleIdentifier(tc.input)
+ assert.Equal(t, tc.expected, actual)
+ })
+ }
+}
diff --git a/internal/queries/README.md b/internal/queries/README.md
new file mode 100644
index 0000000..74f4f1f
--- /dev/null
+++ b/internal/queries/README.md
@@ -0,0 +1,15 @@
+# Expected workflow
+
+1. Add a SQL query/statement in `queries.sql`, following the other examples.
+2. Run `make sqlc`, using the same version of `sqlc` as in `build/Dockerfile.codegen`.
+
+`sqlc` decides the return type of the generated method based on the query's annotation suffix (documentation [here](https://docs.sqlc.dev/en/latest/reference/query-annotations.html)):
+ - `:exec` will only tell you whether the query succeeded: `error`
+ - `:execrows` will tell you how many rows were affected: `(int64, error)`
+ - `:one` will give you back a single struct: `(Author, error)`
+ - `:many` will give you back a slice of structs: `([]Author, error)`
+
+The code generation is configured by the `sqlc.yaml` file. You can read docs about the various
+options it supports
+[here](https://docs.sqlc.dev/en/latest/reference/config.html).
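+
+For illustration, here is a rough sketch of how the generated methods are consumed from Go code
+elsewhere in this repo; it is not generated by sqlc. The DSN below is a placeholder, and error
+handling is minimal:
+
+```go
+package main
+
+import (
+	"context"
+	"database/sql"
+	"log"
+
+	_ "github.com/jackc/pgx/v4/stdlib"
+	"github.com/stripe/pg-schema-diff/internal/queries"
+)
+
+func main() {
+	// Placeholder DSN; point this at a real database.
+	db, err := sql.Open("pgx", "dbname=postgres sslmode=disable")
+	if err != nil {
+		log.Fatal(err)
+	}
+	defer db.Close()
+
+	q := queries.New(db)
+	// GetTables is annotated with `:many`, so it returns a slice of rows and an error.
+	tables, err := q.GetTables(context.Background())
+	if err != nil {
+		log.Fatal(err)
+	}
+	for _, table := range tables {
+		log.Println(table.TableName)
+	}
+}
+```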
diff --git a/internal/queries/dml.sql.go b/internal/queries/dml.sql.go
new file mode 100644
index 0000000..9df9a79
--- /dev/null
+++ b/internal/queries/dml.sql.go
@@ -0,0 +1,31 @@
+// Code generated by sqlc. DO NOT EDIT.
+// versions:
+// sqlc v1.13.0
+
+package queries
+
+import (
+ "context"
+ "database/sql"
+)
+
+type DBTX interface {
+ ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
+ PrepareContext(context.Context, string) (*sql.Stmt, error)
+ QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
+ QueryRowContext(context.Context, string, ...interface{}) *sql.Row
+}
+
+func New(db DBTX) *Queries {
+ return &Queries{db: db}
+}
+
+type Queries struct {
+ db DBTX
+}
+
+func (q *Queries) WithTx(tx *sql.Tx) *Queries {
+ return &Queries{
+ db: tx,
+ }
+}
diff --git a/internal/queries/models.sql.go b/internal/queries/models.sql.go
new file mode 100644
index 0000000..2a05a43
--- /dev/null
+++ b/internal/queries/models.sql.go
@@ -0,0 +1,7 @@
+// Code generated by sqlc. DO NOT EDIT.
+// versions:
+// sqlc v1.13.0
+
+package queries
+
+import ()
diff --git a/internal/queries/queries.sql b/internal/queries/queries.sql
new file mode 100644
index 0000000..5e90293
--- /dev/null
+++ b/internal/queries/queries.sql
@@ -0,0 +1,121 @@
+-- name: GetTables :many
+SELECT c.oid AS oid,
+ c.relname AS table_name,
+ COALESCE(parent_c.relname, '') AS parent_table_name,
+ COALESCE(parent_namespace.nspname, '') AS parent_table_schema_name,
+ (CASE
+ WHEN c.relkind = 'p' THEN pg_catalog.pg_get_partkeydef(c.oid)
+ ELSE ''
+ END)::text
+ AS partition_key_def,
+ (CASE
+ WHEN c.relispartition THEN pg_catalog.pg_get_expr(c.relpartbound, c.oid)
+ ELSE ''
+ END)::text AS partition_for_values
+FROM pg_catalog.pg_class c
+ LEFT JOIN pg_catalog.pg_inherits inherits ON inherits.inhrelid = c.oid
+ LEFT JOIN pg_catalog.pg_class parent_c ON inherits.inhparent = parent_c.oid
+ LEFT JOIN pg_catalog.pg_namespace as parent_namespace ON parent_c.relnamespace = parent_namespace.oid
+WHERE c.relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = 'public')
+ AND (c.relkind = 'r' OR c.relkind = 'p');
+
+-- name: GetColumnsForTable :many
+SELECT a.attname AS column_name,
+ pg_catalog.format_type(a.atttypid, a.atttypmod) AS column_type,
+ COALESCE(coll.collname, '') AS collation_name,
+ COALESCE(collation_namespace.nspname, '') AS collation_schema_name,
+ COALESCE(pg_catalog.pg_get_expr(d.adbin, d.adrelid), '')::TEXT AS default_value,
+ a.attnotnull AS is_not_null,
+ a.attlen AS column_size
+FROM pg_catalog.pg_attribute a
+ LEFT JOIN pg_catalog.pg_attrdef d ON (d.adrelid = a.attrelid AND d.adnum = a.attnum)
+ LEFT JOIN pg_catalog.pg_collation coll ON coll.oid = a.attcollation
+ LEFT JOIN pg_catalog.pg_namespace collation_namespace ON collation_namespace.oid = coll.collnamespace
+WHERE a.attrelid = $1
+ AND a.attnum > 0
+ AND NOT a.attisdropped
+ORDER BY a.attnum;
+
+-- name: GetIndexes :many
+SELECT c.oid AS oid,
+ c.relname as index_name,
+ (SELECT c1.relname AS table_name FROM pg_catalog.pg_class c1 WHERE c1.oid = i.indrelid),
+ pg_catalog.pg_get_indexdef(c.oid)::TEXT as def_stmt,
+ COALESCE(con.conname, '') as constraint_name,
+ i.indisvalid as index_is_valid,
+ i.indisprimary as index_is_pk,
+ i.indisunique AS index_is_unique,
+ COALESCE(parent_c.relname, '') as parent_index_name,
+ COALESCE(parent_namespace.nspname, '') as parent_index_schema_name
+FROM pg_catalog.pg_class c
+ INNER JOIN pg_catalog.pg_index i ON (i.indexrelid = c.oid)
+ LEFT JOIN pg_catalog.pg_constraint con ON (con.conindid = c.oid)
+ LEFT JOIN pg_catalog.pg_inherits inherits ON (c.oid = inherits.inhrelid)
+ LEFT JOIN pg_catalog.pg_class parent_c ON (inherits.inhparent = parent_c.oid)
+ LEFT JOIN pg_catalog.pg_namespace as parent_namespace ON parent_c.relnamespace = parent_namespace.oid
+WHERE c.relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = 'public')
+ AND (c.relkind = 'i' OR c.relkind = 'I');
+
+-- name: GetColumnsForIndex :many
+SELECT a.attname AS column_name
+FROM pg_catalog.pg_attribute a
+WHERE a.attrelid = $1
+ AND a.attnum > 0
+ORDER BY a.attnum;
+
+-- name: GetCheckConstraints :many
+SELECT pg_constraint.oid,
+ conname as name,
+ pg_class.relname as table_name,
+ pg_catalog.pg_get_expr(conbin, conrelid) as expression,
+ convalidated as is_valid,
+ connoinherit as is_not_inheritable
+FROM pg_catalog.pg_constraint
+ JOIN pg_catalog.pg_class ON pg_constraint.conrelid = pg_class.oid
+WHERE pg_class.relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = 'public')
+ AND contype = 'c'
+ AND pg_constraint.conislocal;
+
+-- name: GetFunctions :many
+SELECT proc.oid,
+ proname as func_name,
+ pg_catalog.pg_get_function_identity_arguments(proc.oid) as func_identity_arguments,
+ proc_namespace.nspname as func_schema_name,
+ pg_catalog.pg_get_functiondef(proc.oid) as func_def,
+ proc_lang.lanname as func_lang
+FROM pg_catalog.pg_proc proc
+ JOIN pg_catalog.pg_namespace proc_namespace ON proc.pronamespace = proc_namespace.oid
+ JOIN pg_catalog.pg_language proc_lang ON proc_lang.oid = proc.prolang
+WHERE proc_namespace.nspname = 'public'
+ AND proc.prokind = 'f'
+ -- Exclude functions belonging to extensions
+ AND NOT EXISTS(SELECT depend.objid FROM pg_catalog.pg_depend depend WHERE deptype = 'e' AND depend.objid = proc.oid);
+
+-- name: GetDependsOnFunctions :many
+SELECT proc.proname as func_name,
+ pg_catalog.pg_get_function_identity_arguments(proc.oid) as func_identity_arguments,
+ proc_namespace.nspname as func_schema_name
+FROM pg_catalog.pg_depend depend
+ JOIN pg_catalog.pg_proc proc ON depend.refobjid = proc.oid
+ JOIN pg_catalog.pg_namespace proc_namespace ON proc.pronamespace = proc_namespace.oid
+WHERE depend.objid = $1
+ AND depend.deptype = 'n';
+
+-- name: GetTriggers :many
+SELECT trig.tgname as trigger_name,
+ owning_c.relname as owning_table_name,
+ owning_c_namespace.nspname as owning_table_schema_name,
+ proc.proname as func_name,
+ pg_catalog.pg_get_function_identity_arguments(proc.oid) as func_identity_arguments,
+ proc_namespace.nspname as func_schema_name,
+ pg_catalog.pg_get_triggerdef(trig.oid) as trigger_def
+FROM pg_catalog.pg_trigger trig
+ JOIN pg_catalog.pg_class owning_c ON trig.tgrelid = owning_c.oid
+ JOIN pg_catalog.pg_namespace owning_c_namespace ON owning_c.relnamespace = owning_c_namespace.oid
+ JOIN pg_catalog.pg_proc proc ON trig.tgfoid = proc.oid
+ JOIN pg_catalog.pg_namespace proc_namespace ON proc.pronamespace = proc_namespace.oid
+WHERE proc_namespace.nspname = 'public'
+ AND owning_c_namespace.nspname = 'public'
+ AND trig.tgparentid = 0
+ AND NOT trig.tgisinternal;
+
diff --git a/internal/queries/queries.sql.go b/internal/queries/queries.sql.go
new file mode 100644
index 0000000..57e62dd
--- /dev/null
+++ b/internal/queries/queries.sql.go
@@ -0,0 +1,437 @@
+// Code generated by sqlc. DO NOT EDIT.
+// versions:
+// sqlc v1.13.0
+// source: queries.sql
+
+package queries
+
+import (
+ "context"
+)
+
+const getCheckConstraints = `-- name: GetCheckConstraints :many
+SELECT pg_constraint.oid,
+ conname as name,
+ pg_class.relname as table_name,
+ pg_catalog.pg_get_expr(conbin, conrelid) as expression,
+ convalidated as is_valid,
+ connoinherit as is_not_inheritable
+FROM pg_catalog.pg_constraint
+ JOIN pg_catalog.pg_class ON pg_constraint.conrelid = pg_class.oid
+WHERE pg_class.relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = 'public')
+ AND contype = 'c'
+ AND pg_constraint.conislocal
+`
+
+type GetCheckConstraintsRow struct {
+ Oid interface{}
+ Name string
+ TableName string
+ Expression string
+ IsValid bool
+ IsNotInheritable bool
+}
+
+func (q *Queries) GetCheckConstraints(ctx context.Context) ([]GetCheckConstraintsRow, error) {
+ rows, err := q.db.QueryContext(ctx, getCheckConstraints)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+ var items []GetCheckConstraintsRow
+ for rows.Next() {
+ var i GetCheckConstraintsRow
+ if err := rows.Scan(
+ &i.Oid,
+ &i.Name,
+ &i.TableName,
+ &i.Expression,
+ &i.IsValid,
+ &i.IsNotInheritable,
+ ); err != nil {
+ return nil, err
+ }
+ items = append(items, i)
+ }
+ if err := rows.Close(); err != nil {
+ return nil, err
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return items, nil
+}
+
+const getColumnsForIndex = `-- name: GetColumnsForIndex :many
+SELECT a.attname AS column_name
+FROM pg_catalog.pg_attribute a
+WHERE a.attrelid = $1
+ AND a.attnum > 0
+ORDER BY a.attnum
+`
+
+func (q *Queries) GetColumnsForIndex(ctx context.Context, attrelid interface{}) ([]string, error) {
+ rows, err := q.db.QueryContext(ctx, getColumnsForIndex, attrelid)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+ var items []string
+ for rows.Next() {
+ var column_name string
+ if err := rows.Scan(&column_name); err != nil {
+ return nil, err
+ }
+ items = append(items, column_name)
+ }
+ if err := rows.Close(); err != nil {
+ return nil, err
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return items, nil
+}
+
+const getColumnsForTable = `-- name: GetColumnsForTable :many
+SELECT a.attname AS column_name,
+ pg_catalog.format_type(a.atttypid, a.atttypmod) AS column_type,
+ COALESCE(coll.collname, '') AS collation_name,
+ COALESCE(collation_namespace.nspname, '') AS collation_schema_name,
+ COALESCE(pg_catalog.pg_get_expr(d.adbin, d.adrelid), '')::TEXT AS default_value,
+ a.attnotnull AS is_not_null,
+ a.attlen AS column_size
+FROM pg_catalog.pg_attribute a
+ LEFT JOIN pg_catalog.pg_attrdef d ON (d.adrelid = a.attrelid AND d.adnum = a.attnum)
+ LEFT JOIN pg_catalog.pg_collation coll ON coll.oid = a.attcollation
+ LEFT JOIN pg_catalog.pg_namespace collation_namespace ON collation_namespace.oid = coll.collnamespace
+WHERE a.attrelid = $1
+ AND a.attnum > 0
+ AND NOT a.attisdropped
+ORDER BY a.attnum
+`
+
+type GetColumnsForTableRow struct {
+ ColumnName string
+ ColumnType string
+ CollationName string
+ CollationSchemaName string
+ DefaultValue string
+ IsNotNull bool
+ ColumnSize int16
+}
+
+func (q *Queries) GetColumnsForTable(ctx context.Context, attrelid interface{}) ([]GetColumnsForTableRow, error) {
+ rows, err := q.db.QueryContext(ctx, getColumnsForTable, attrelid)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+ var items []GetColumnsForTableRow
+ for rows.Next() {
+ var i GetColumnsForTableRow
+ if err := rows.Scan(
+ &i.ColumnName,
+ &i.ColumnType,
+ &i.CollationName,
+ &i.CollationSchemaName,
+ &i.DefaultValue,
+ &i.IsNotNull,
+ &i.ColumnSize,
+ ); err != nil {
+ return nil, err
+ }
+ items = append(items, i)
+ }
+ if err := rows.Close(); err != nil {
+ return nil, err
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return items, nil
+}
+
+const getDependsOnFunctions = `-- name: GetDependsOnFunctions :many
+SELECT proc.proname as func_name,
+ pg_catalog.pg_get_function_identity_arguments(proc.oid) as func_identity_arguments,
+ proc_namespace.nspname as func_schema_name
+FROM pg_catalog.pg_depend depend
+ JOIN pg_catalog.pg_proc proc ON depend.refobjid = proc.oid
+ JOIN pg_catalog.pg_namespace proc_namespace ON proc.pronamespace = proc_namespace.oid
+WHERE depend.objid = $1
+ AND depend.deptype = 'n'
+`
+
+type GetDependsOnFunctionsRow struct {
+ FuncName string
+ FuncIdentityArguments string
+ FuncSchemaName string
+}
+
+func (q *Queries) GetDependsOnFunctions(ctx context.Context, objid interface{}) ([]GetDependsOnFunctionsRow, error) {
+ rows, err := q.db.QueryContext(ctx, getDependsOnFunctions, objid)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+ var items []GetDependsOnFunctionsRow
+ for rows.Next() {
+ var i GetDependsOnFunctionsRow
+ if err := rows.Scan(&i.FuncName, &i.FuncIdentityArguments, &i.FuncSchemaName); err != nil {
+ return nil, err
+ }
+ items = append(items, i)
+ }
+ if err := rows.Close(); err != nil {
+ return nil, err
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return items, nil
+}
+
+const getFunctions = `-- name: GetFunctions :many
+SELECT proc.oid,
+ proname as func_name,
+ pg_catalog.pg_get_function_identity_arguments(proc.oid) as func_identity_arguments,
+ proc_namespace.nspname as func_schema_name,
+ pg_catalog.pg_get_functiondef(proc.oid) as func_def,
+ proc_lang.lanname as func_lang
+FROM pg_catalog.pg_proc proc
+ JOIN pg_catalog.pg_namespace proc_namespace ON proc.pronamespace = proc_namespace.oid
+ JOIN pg_catalog.pg_language proc_lang ON proc_lang.oid = proc.prolang
+WHERE proc_namespace.nspname = 'public'
+ AND proc.prokind = 'f'
+ -- Exclude functions belonging to extensions
+ AND NOT EXISTS(SELECT depend.objid FROM pg_catalog.pg_depend depend WHERE deptype = 'e' AND depend.objid = proc.oid)
+`
+
+type GetFunctionsRow struct {
+ Oid interface{}
+ FuncName string
+ FuncIdentityArguments string
+ FuncSchemaName string
+ FuncDef string
+ FuncLang string
+}
+
+func (q *Queries) GetFunctions(ctx context.Context) ([]GetFunctionsRow, error) {
+ rows, err := q.db.QueryContext(ctx, getFunctions)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+ var items []GetFunctionsRow
+ for rows.Next() {
+ var i GetFunctionsRow
+ if err := rows.Scan(
+ &i.Oid,
+ &i.FuncName,
+ &i.FuncIdentityArguments,
+ &i.FuncSchemaName,
+ &i.FuncDef,
+ &i.FuncLang,
+ ); err != nil {
+ return nil, err
+ }
+ items = append(items, i)
+ }
+ if err := rows.Close(); err != nil {
+ return nil, err
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return items, nil
+}
+
+const getIndexes = `-- name: GetIndexes :many
+SELECT c.oid AS oid,
+ c.relname as index_name,
+ (SELECT c1.relname AS table_name FROM pg_catalog.pg_class c1 WHERE c1.oid = i.indrelid),
+ pg_catalog.pg_get_indexdef(c.oid)::TEXT as def_stmt,
+ COALESCE(con.conname, '') as constraint_name,
+ i.indisvalid as index_is_valid,
+ i.indisprimary as index_is_pk,
+ i.indisunique AS index_is_unique,
+ COALESCE(parent_c.relname, '') as parent_index_name,
+ COALESCE(parent_namespace.nspname, '') as parent_index_schema_name
+FROM pg_catalog.pg_class c
+ INNER JOIN pg_catalog.pg_index i ON (i.indexrelid = c.oid)
+ LEFT JOIN pg_catalog.pg_constraint con ON (con.conindid = c.oid)
+ LEFT JOIN pg_catalog.pg_inherits inherits ON (c.oid = inherits.inhrelid)
+ LEFT JOIN pg_catalog.pg_class parent_c ON (inherits.inhparent = parent_c.oid)
+ LEFT JOIN pg_catalog.pg_namespace as parent_namespace ON parent_c.relnamespace = parent_namespace.oid
+WHERE c.relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = 'public')
+ AND (c.relkind = 'i' OR c.relkind = 'I')
+`
+
+type GetIndexesRow struct {
+ Oid interface{}
+ IndexName string
+ TableName string
+ DefStmt string
+ ConstraintName string
+ IndexIsValid bool
+ IndexIsPk bool
+ IndexIsUnique bool
+ ParentIndexName string
+ ParentIndexSchemaName string
+}
+
+func (q *Queries) GetIndexes(ctx context.Context) ([]GetIndexesRow, error) {
+ rows, err := q.db.QueryContext(ctx, getIndexes)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+ var items []GetIndexesRow
+ for rows.Next() {
+ var i GetIndexesRow
+ if err := rows.Scan(
+ &i.Oid,
+ &i.IndexName,
+ &i.TableName,
+ &i.DefStmt,
+ &i.ConstraintName,
+ &i.IndexIsValid,
+ &i.IndexIsPk,
+ &i.IndexIsUnique,
+ &i.ParentIndexName,
+ &i.ParentIndexSchemaName,
+ ); err != nil {
+ return nil, err
+ }
+ items = append(items, i)
+ }
+ if err := rows.Close(); err != nil {
+ return nil, err
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return items, nil
+}
+
+const getTables = `-- name: GetTables :many
+SELECT c.oid AS oid,
+ c.relname AS table_name,
+ COALESCE(parent_c.relname, '') AS parent_table_name,
+ COALESCE(parent_namespace.nspname, '') AS parent_table_schema_name,
+ (CASE
+ WHEN c.relkind = 'p' THEN pg_catalog.pg_get_partkeydef(c.oid)
+ ELSE ''
+ END)::text
+ AS partition_key_def,
+ (CASE
+ WHEN c.relispartition THEN pg_catalog.pg_get_expr(c.relpartbound, c.oid)
+ ELSE ''
+ END)::text AS partition_for_values
+FROM pg_catalog.pg_class c
+ LEFT JOIN pg_catalog.pg_inherits inherits ON inherits.inhrelid = c.oid
+ LEFT JOIN pg_catalog.pg_class parent_c ON inherits.inhparent = parent_c.oid
+ LEFT JOIN pg_catalog.pg_namespace as parent_namespace ON parent_c.relnamespace = parent_namespace.oid
+WHERE c.relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = 'public')
+ AND (c.relkind = 'r' OR c.relkind = 'p')
+`
+
+type GetTablesRow struct {
+ Oid interface{}
+ TableName string
+ ParentTableName string
+ ParentTableSchemaName string
+ PartitionKeyDef string
+ PartitionForValues string
+}
+
+func (q *Queries) GetTables(ctx context.Context) ([]GetTablesRow, error) {
+ rows, err := q.db.QueryContext(ctx, getTables)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+ var items []GetTablesRow
+ for rows.Next() {
+ var i GetTablesRow
+ if err := rows.Scan(
+ &i.Oid,
+ &i.TableName,
+ &i.ParentTableName,
+ &i.ParentTableSchemaName,
+ &i.PartitionKeyDef,
+ &i.PartitionForValues,
+ ); err != nil {
+ return nil, err
+ }
+ items = append(items, i)
+ }
+ if err := rows.Close(); err != nil {
+ return nil, err
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return items, nil
+}
+
+const getTriggers = `-- name: GetTriggers :many
+SELECT trig.tgname as trigger_name,
+ owning_c.relname as owning_table_name,
+ owning_c_namespace.nspname as owning_table_schema_name,
+ proc.proname as func_name,
+ pg_catalog.pg_get_function_identity_arguments(proc.oid) as func_identity_arguments,
+ proc_namespace.nspname as func_schema_name,
+ pg_catalog.pg_get_triggerdef(trig.oid) as trigger_def
+FROM pg_catalog.pg_trigger trig
+ JOIN pg_catalog.pg_class owning_c ON trig.tgrelid = owning_c.oid
+ JOIN pg_catalog.pg_namespace owning_c_namespace ON owning_c.relnamespace = owning_c_namespace.oid
+ JOIN pg_catalog.pg_proc proc ON trig.tgfoid = proc.oid
+ JOIN pg_catalog.pg_namespace proc_namespace ON proc.pronamespace = proc_namespace.oid
+WHERE proc_namespace.nspname = 'public'
+ AND owning_c_namespace.nspname = 'public'
+ AND trig.tgparentid = 0
+ AND NOT trig.tgisinternal
+`
+
+type GetTriggersRow struct {
+ TriggerName string
+ OwningTableName string
+ OwningTableSchemaName string
+ FuncName string
+ FuncIdentityArguments string
+ FuncSchemaName string
+ TriggerDef string
+}
+
+func (q *Queries) GetTriggers(ctx context.Context) ([]GetTriggersRow, error) {
+ rows, err := q.db.QueryContext(ctx, getTriggers)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+ var items []GetTriggersRow
+ for rows.Next() {
+ var i GetTriggersRow
+ if err := rows.Scan(
+ &i.TriggerName,
+ &i.OwningTableName,
+ &i.OwningTableSchemaName,
+ &i.FuncName,
+ &i.FuncIdentityArguments,
+ &i.FuncSchemaName,
+ &i.TriggerDef,
+ ); err != nil {
+ return nil, err
+ }
+ items = append(items, i)
+ }
+ if err := rows.Close(); err != nil {
+ return nil, err
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return items, nil
+}
diff --git a/internal/queries/sqlc.yaml b/internal/queries/sqlc.yaml
new file mode 100644
index 0000000..d1c568a
--- /dev/null
+++ b/internal/queries/sqlc.yaml
@@ -0,0 +1,10 @@
+version: 1
+packages:
+ - path: "."
+ name: "queries"
+ engine: "postgresql"
+ queries: "queries.sql"
+ schema: "system_tables.sql"
+ output_db_file_name: "dml.sql.go"
+ output_models_file_name: "models.sql.go"
+ output_querier_file_name: "querier.sql.go"
diff --git a/internal/queries/system_tables.sql b/internal/queries/system_tables.sql
new file mode 100644
index 0000000..b3cf533
--- /dev/null
+++ b/internal/queries/system_tables.sql
@@ -0,0 +1,385 @@
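+-- Hand-maintained stubs of the pg_catalog tables referenced by queries.sql. sqlc type-checks the
+-- queries against this schema (see sqlc.yaml), so each stub only needs to declare the columns the
+-- queries actually use. The commented "\d" output above each stub is copied from psql for reference.
+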
+-- postgres=# \d pg_catalog.pg_namespace;
+-- Table "pg_catalog.pg_namespace"
+-- Column | Type | Collation | Nullable | Default
+-- ----------+-----------+-----------+----------+---------
+-- oid | oid | | not null |
+-- nspname | name | | not null |
+-- nspowner | oid | | not null |
+-- nspacl | aclitem[] | | |
+
+CREATE TABLE pg_catalog.pg_namespace
+(
+ oid OID NOT NULL,
+ nspname TEXT NOT NULL
+);
+
+
+-- postgres=# \d pg_catalog.pg_class;
+-- Table "pg_catalog.pg_class"
+-- Column | Type | Collation | Nullable | Default
+-- ---------------------+--------------+-----------+----------+---------
+-- oid | oid | | not null |
+-- relname | name | | not null |
+-- relnamespace | oid | | not null |
+-- reltype | oid | | not null |
+-- reloftype | oid | | not null |
+-- relowner | oid | | not null |
+-- relam | oid | | not null |
+-- relfilenode | oid | | not null |
+-- reltablespace | oid | | not null |
+-- relpages | integer | | not null |
+-- reltuples | real | | not null |
+-- relallvisible | integer | | not null |
+-- reltoastrelid | oid | | not null |
+-- relhasindex | boolean | | not null |
+-- relisshared | boolean | | not null |
+-- relpersistence | "char" | | not null |
+-- relkind | "char" | | not null |
+-- relnatts | smallint | | not null |
+-- relchecks | smallint | | not null |
+-- relhasrules | boolean | | not null |
+-- relhastriggers | boolean | | not null |
+-- relhassubclass | boolean | | not null |
+-- relrowsecurity | boolean | | not null |
+-- relforcerowsecurity | boolean | | not null |
+-- relispopulated | boolean | | not null |
+-- relreplident | "char" | | not null |
+-- relispartition | boolean | | not null |
+-- relrewrite | oid | | not null |
+-- relfrozenxid | xid | | not null |
+-- relminmxid | xid | | not null |
+-- relacl | aclitem[] | | |
+-- reloptions | text[] | C | |
+-- relpartbound | pg_node_tree | C | |
+-- Indexes:
+-- "pg_class_oid_index" PRIMARY KEY, btree (oid)
+-- "pg_class_relname_nsp_index" UNIQUE CONSTRAINT, btree (relname, relnamespace)
+-- "pg_class_tblspc_relfilenode_index" btree (reltablespace, relfilenode)
+
+CREATE TABLE pg_catalog.pg_class
+(
+ oid OID NOT NULL,
+ relname TEXT NOT NULL,
+ relnamespace INT NOT NULL,
+ reltype INT NOT NULL
+);
+
+-- postgres=# \d pg_catalog.pg_inherits
+-- Column | Type | Collation | Nullable | Default
+-- ------------------+---------+-----------+----------+---------
+-- inhrelid | oid | | not null |
+-- inhparent | oid | | not null |
+-- inhseqno | integer | | not null |
+-- inhdetachpending | boolean | | not null |
+-- Indexes:
+-- "pg_inherits_relid_seqno_index" PRIMARY KEY, btree (inhrelid, inhseqno)
+-- "pg_inherits_parent_index" btree (inhparent)
+--
+CREATE TABLE pg_catalog.pg_inherits
+(
+ inhrelid OID NOT NULL,
+ inhparent OID NOT NULL
+);
+
+-- postgres=# \d pg_catalog.pg_attribute;
+-- Table "pg_catalog.pg_attribute"
+-- Column | Type | Collation | Nullable | Default
+-- ----------------+-----------+-----------+----------+---------
+-- attrelid | oid | | not null |
+-- attname | name | | not null |
+-- atttypid | oid | | not null |
+-- attstattarget | integer | | not null |
+-- attlen | smallint | | not null |
+-- attnum | smallint | | not null |
+-- attndims | integer | | not null |
+-- attcacheoff | integer | | not null |
+-- atttypmod | integer | | not null |
+-- attbyval | boolean | | not null |
+-- attalign | "char" | | not null |
+-- attstorage | "char" | | not null |
+-- attcompression | "char" | | not null |
+-- attnotnull | boolean | | not null |
+-- atthasdef | boolean | | not null |
+-- atthasmissing | boolean | | not null |
+-- attidentity | "char" | | not null |
+-- attgenerated | "char" | | not null |
+-- attisdropped | boolean | | not null |
+-- attislocal | boolean | | not null |
+-- attinhcount | integer | | not null |
+-- attcollation | oid | | not null |
+-- attacl | aclitem[] | | |
+-- attoptions | text[] | C | |
+-- attfdwoptions | text[] | C | |
+-- attmissingval | anyarray | | |
+-- Indexes:
+-- "pg_attribute_relid_attnum_index" PRIMARY KEY, btree (attrelid, attnum)
+-- "pg_attribute_relid_attnam_index" UNIQUE CONSTRAINT, btree (attrelid, attname)
+
+CREATE TABLE pg_catalog.pg_attribute
+(
+ attrelid OID NOT NULL,
+ attname TEXT NOT NULL,
+ attlen SMALLINT NOT NULL,
+ attnum SMALLINT NOT NULL,
+ atttypid INTEGER NOT NULL,
+ atttypmod INTEGER NOT NULL,
+ attnotnull BOOLEAN NOT NULL,
+ attcollation OID NOT NULL
+);
+
+-- postgres=# \d pg_catalog.pg_collation;
+-- Table "pg_catalog.pg_collation"
+-- Column | Type | Collation | Nullable | Default
+-- ---------------------+---------+-----------+----------+---------
+-- oid | oid | | not null |
+-- collname | name | | not null |
+-- collnamespace | oid | | not null |
+-- collowner | oid | | not null |
+-- collprovider | "char" | | not null |
+-- collisdeterministic | boolean | | not null |
+-- collencoding | integer | | not null |
+-- collcollate | name | | not null |
+-- collctype | name | | not null |
+-- collversion | text | C | |
+-- Indexes:
+-- "pg_collation_oid_index" PRIMARY KEY, btree (oid)
+-- "pg_collation_name_enc_nsp_index" UNIQUE CONSTRAINT, btree (collname, collencoding, collnamespace)
+CREATE TABLE pg_catalog.pg_collation
+(
+ oid OID NOT NULL,
+ collname TEXT NOT NULL,
+ collnamespace OID NOT NULL
+);
+
+-- postgres=# \d pg_index;
+-- Table "pg_catalog.pg_index"
+-- Column | Type | Collation | Nullable | Default
+-- ----------------+--------------+-----------+----------+---------
+-- indexrelid | oid | | not null |
+-- indrelid | oid | | not null |
+-- indnatts | smallint | | not null |
+-- indnkeyatts | smallint | | not null |
+-- indisunique | boolean | | not null |
+-- indisprimary | boolean | | not null |
+-- indisexclusion | boolean | | not null |
+-- indimmediate | boolean | | not null |
+-- indisclustered | boolean | | not null |
+-- indisvalid | boolean | | not null |
+-- indcheckxmin | boolean | | not null |
+-- indisready | boolean | | not null |
+-- indislive | boolean | | not null |
+-- indisreplident | boolean | | not null |
+-- indkey | int2vector | | not null |
+-- indcollation | oidvector | | not null |
+-- indclass | oidvector | | not null |
+-- indoption | int2vector | | not null |
+-- indexprs | pg_node_tree | C | |
+-- indpred | pg_node_tree | C | |
+-- Indexes:
+-- "pg_index_indexrelid_index" PRIMARY KEY, btree (indexrelid)
+-- "pg_index_indrelid_index" btree (indrelid)
+CREATE TABLE pg_catalog.pg_index
+(
+ indexrelid OID NOT NULL,
+ indrelid OID NOT NULL,
+ indisunique BOOLEAN NOT NULL,
+ indisprimary BOOLEAN NOT NULL,
+ indisvalid BOOLEAN NOT NULL
+);
+
+-- postgres=# \d pg_catalog.pg_attrdef;
+-- Table "pg_catalog.pg_attrdef"
+-- Column | Type | Collation | Nullable | Default
+-- ---------+--------------+-----------+----------+---------
+-- oid | oid | | not null |
+-- adrelid | oid | | not null |
+-- adnum | smallint | | not null |
+-- adbin | pg_node_tree | C | not null |
+-- Indexes:
+-- "pg_attrdef_oid_index" PRIMARY KEY, btree (oid)
+-- "pg_attrdef_adrelid_adnum_index" UNIQUE CONSTRAINT, btree (adrelid, adnum)
+CREATE TABLE pg_catalog.pg_attrdef
+(
+ adrelid OID NOT NULL,
+ adnum SMALLINT NOT NULL,
+ adbin PG_NODE_TREE NOT NULL
+);
+
+-- postgres=# \d pg_catalog.pg_depend
+-- Table "pg_catalog.pg_depend"
+-- Column | Type | Collation | Nullable | Default
+-- -------------+---------+-----------+----------+---------
+-- classid | oid | | not null |
+-- objid | oid | | not null |
+-- objsubid | integer | | not null |
+-- refclassid | oid | | not null |
+-- refobjid | oid | | not null |
+-- refobjsubid | integer | | not null |
+-- deptype | "char" | | not null |
+-- Indexes:
+-- "pg_depend_depender_index" btree (classid, objid, objsubid)
+-- "pg_depend_reference_index" btree (refclassid, refobjid, refobjsubid)
+CREATE TABLE pg_catalog.pg_depend
+(
+ objid OID NOT NULL,
+ refobjid OID NOT NULL,
+ deptype CHAR NOT NULL
+);
+
+
+-- postgres=# \d pg_catalog.pg_constraint
+-- Table "pg_catalog.pg_constraint"
+-- Column | Type | Collation | Nullable | Default
+-- ---------------+--------------+-----------+----------+---------
+-- oid | oid | | not null |
+-- conname | name | | not null |
+-- connamespace | oid | | not null |
+-- contype | "char" | | not null |
+-- condeferrable | boolean | | not null |
+-- condeferred | boolean | | not null |
+-- convalidated | boolean | | not null |
+-- conrelid | oid | | not null |
+-- contypid | oid | | not null |
+-- conindid | oid | | not null |
+-- conparentid | oid | | not null |
+-- confrelid | oid | | not null |
+-- confupdtype | "char" | | not null |
+-- confdeltype | "char" | | not null |
+-- confmatchtype | "char" | | not null |
+-- conislocal | boolean | | not null |
+-- coninhcount | integer | | not null |
+-- connoinherit | boolean | | not null |
+-- conkey | smallint[] | | |
+-- confkey | smallint[] | | |
+-- conpfeqop | oid[] | | |
+-- conppeqop | oid[] | | |
+-- conffeqop | oid[] | | |
+-- conexclop | oid[] | | |
+-- conbin | pg_node_tree | C | |
+-- Indexes:
+-- "pg_constraint_oid_index" PRIMARY KEY, btree (oid)
+-- "pg_constraint_conname_nsp_index" btree (conname, connamespace)
+-- "pg_constraint_conparentid_index" btree (conparentid)
+-- "pg_constraint_conrelid_contypid_conname_index" UNIQUE CONSTRAINT, btree (conrelid, contypid, conname)
+-- "pg_constraint_contypid_index" btree (contypid)
+CREATE TABLE pg_catalog.pg_constraint
+(
+ oid OID NOT NULL,
+ conname TEXT NOT NULL,
+ conindid OID NOT NULL,
+ contype CHAR NOT NULL,
+ condeferrable BOOLEAN NOT NULL,
+ condeferred BOOLEAN NOT NULL,
+ convalidated BOOLEAN NOT NULL,
+ conrelid OID NOT NULL,
+ conislocal BOOLEAN NOT NULL,
+ connoinherit BOOLEAN NOT NULL,
+ conbin PG_NODE_TREE NOT NULL
+);
+
+-- postgres=# \d pg_catalog.pg_proc
+-- Table "pg_catalog.pg_proc"
+-- Column | Type | Collation | Nullable | Default
+-- -----------------+--------------+-----------+----------+---------
+-- oid | oid | | not null |
+-- proname | name | | not null |
+-- pronamespace | oid | | not null |
+-- proowner | oid | | not null |
+-- prolang | oid | | not null |
+-- procost | real | | not null |
+-- prorows | real | | not null |
+-- provariadic | oid | | not null |
+-- prosupport | regproc | | not null |
+-- prokind | "char" | | not null |
+-- prosecdef | boolean | | not null |
+-- proleakproof | boolean | | not null |
+-- proisstrict | boolean | | not null |
+-- proretset | boolean | | not null |
+-- provolatile | "char" | | not null |
+-- proparallel | "char" | | not null |
+-- pronargs | smallint | | not null |
+-- pronargdefaults | smallint | | not null |
+-- prorettype | oid | | not null |
+-- proargtypes | oidvector | | not null |
+-- proallargtypes | oid[] | | |
+-- proargmodes | "char"[] | | |
+-- proargnames | text[] | C | |
+-- proargdefaults | pg_node_tree | C | |
+-- protrftypes | oid[] | | |
+-- prosrc | text | C | not null |
+-- probin | text | C | |
+-- prosqlbody | pg_node_tree | C | |
+-- proconfig | text[] | C | |
+-- proacl | aclitem[] | | |
+-- Indexes:
+-- "pg_proc_oid_index" PRIMARY KEY, btree (oid)
+-- "pg_proc_proname_args_nsp_index" UNIQUE CONSTRAINT, btree (proname, proargtypes, pronamespace)
+CREATE TABLE pg_catalog.pg_proc
+(
+ oid OID NOT NULL,
+ proname TEXT NOT NULL,
+ pronamespace OID NOT NULL,
+ prolang OID NOT NULL,
+ prokind CHAR NOT NULL
+);
+
+
+-- postgres=# \d pg_catalog.pg_language
+-- Table "pg_catalog.pg_language"
+-- Column | Type | Collation | Nullable | Default
+-- ---------------+-----------+-----------+----------+---------
+-- oid | oid | | not null |
+-- lanname | name | | not null |
+-- lanowner | oid | | not null |
+-- lanispl | boolean | | not null |
+-- lanpltrusted | boolean | | not null |
+-- lanplcallfoid | oid | | not null |
+-- laninline | oid | | not null |
+-- lanvalidator | oid | | not null |
+-- lanacl | aclitem[] | | |
+-- Indexes:
+-- "pg_language_oid_index" PRIMARY KEY, btree (oid)
+-- "pg_language_name_index" UNIQUE CONSTRAINT, btree (lanname)
+CREATE TABLE pg_catalog.pg_language
+(
+ oid OID NOT NULL,
+ lanname TEXT NOT NULL
+);
+
+
+-- postgres=# \d pg_catalog.pg_trigger;
+-- Table "pg_catalog.pg_trigger"
+-- Column | Type | Collation | Nullable | Default
+-- ----------------+--------------+-----------+----------+---------
+-- oid | oid | | not null |
+-- tgrelid | oid | | not null |
+-- tgparentid | oid | | not null |
+-- tgname | name | | not null |
+-- tgfoid | oid | | not null |
+-- tgtype | smallint | | not null |
+-- tgenabled | "char" | | not null |
+-- tgisinternal | boolean | | not null |
+-- tgconstrrelid | oid | | not null |
+-- tgconstrindid | oid | | not null |
+-- tgconstraint | oid | | not null |
+-- tgdeferrable | boolean | | not null |
+-- tginitdeferred | boolean | | not null |
+-- tgnargs | smallint | | not null |
+-- tgattr | int2vector | | not null |
+-- tgargs | bytea | | not null |
+-- tgqual | pg_node_tree | C | |
+-- tgoldtable | name | | |
+-- tgnewtable | name | | |
+-- Indexes:
+-- "pg_trigger_oid_index" PRIMARY KEY, btree (oid)
+-- "pg_trigger_tgrelid_tgname_index" UNIQUE CONSTRAINT, btree (tgrelid, tgname)
+-- "pg_trigger_tgconstraint_index" btree (tgconstraint)
+CREATE TABLE pg_catalog.pg_trigger
+(
+ oid OID NOT NULL,
+ tgrelid OID NOT NULL,
+ tgparentid OID NOT NULL,
+ tgname TEXT NOT NULL,
+ tgfoid OID NOT NULL,
+ tgisinternal BOOLEAN NOT NULL
+);
diff --git a/internal/schema/schema.go b/internal/schema/schema.go
new file mode 100644
index 0000000..131cfc0
--- /dev/null
+++ b/internal/schema/schema.go
@@ -0,0 +1,500 @@
+package schema
+
+import (
+ "context"
+ "fmt"
+ "regexp"
+ "sort"
+
+ "github.com/mitchellh/hashstructure/v2"
+ "github.com/stripe/pg-schema-diff/internal/queries"
+)
+
+type (
+ // Object represents a resource in a schema (table, column, index...)
+ Object interface {
+ // GetName is used to identify the old and new versions of a schema object between the old and new schemas
+ // If the name is not present in the old schema's list of objects, then it is added
+ // If the name is not present in the new schema's list of objects, then it is removed
+ // Otherwise, it has persisted across two schemas and is possibly altered
+ GetName() string
+ }
+
+ // SchemaQualifiedName represents a schema object name scoped within a schema
+ SchemaQualifiedName struct {
+ SchemaName string
+ // EscapedName is the name of the object. It should already be escaped
+ // We take an escaped name because there are weird exceptions, like functions, where we can't just
+ // surround the name in quotes
+ EscapedName string
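+ // For example (illustrative), {SchemaName: "public", EscapedName: `"foobar"`} yields
+ // "public"."foobar" from GetFQEscapedName, whereas a function might have an EscapedName
+ // like `"add"(a integer, b integer)`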
+ }
+)
+
+func (o SchemaQualifiedName) GetName() string {
+ return o.GetFQEscapedName()
+}
+
+// GetFQEscapedName gets the fully-qualified, escaped name of the schema object, including the schema name
+func (o SchemaQualifiedName) GetFQEscapedName() string {
+ return fmt.Sprintf("%s.%s", EscapeIdentifier(o.SchemaName), o.EscapedName)
+}
+
+func (o SchemaQualifiedName) IsEmpty() bool {
+ return len(o.SchemaName) == 0
+}
+
+type Schema struct {
+ // Name refers to the name of the schema. Ultimately, schema objects can cut across
+ // schemas, e.g., a partition of a table can exist in a different schema. Thus, we're probably
+ // going to delete this Name attribute soon, once multi-schema is supported
+ Name string
+ Tables []Table
+ Indexes []Index
+
+ Functions []Function
+ Triggers []Trigger
+}
+
+func (s Schema) GetName() string {
+ return s.Name
+}
+
+// Normalize normalizes the schema (alphabetically sorts tables, check constraints, indexes, functions, and triggers)
+// Useful for hashing and testing
+func (s Schema) Normalize() Schema {
+ var normTables []Table
+ for _, table := range sortSchemaObjectsByName(s.Tables) {
+ // Don't normalize column order; their order is derived from the Postgres catalogs
+ // (relevant to data packing)
+ var normCheckConstraints []CheckConstraint
+ for _, checkConstraint := range sortSchemaObjectsByName(table.CheckConstraints) {
+ checkConstraint.DependsOnFunctions = sortSchemaObjectsByName(checkConstraint.DependsOnFunctions)
+ normCheckConstraints = append(normCheckConstraints, checkConstraint)
+ }
+ table.CheckConstraints = normCheckConstraints
+ normTables = append(normTables, table)
+ }
+ s.Tables = normTables
+
+ s.Indexes = sortSchemaObjectsByName(s.Indexes)
+
+ var normFunctions []Function
+ for _, function := range sortSchemaObjectsByName(s.Functions) {
+ function.DependsOnFunctions = sortSchemaObjectsByName(function.DependsOnFunctions)
+ normFunctions = append(normFunctions, function)
+ }
+ s.Functions = normFunctions
+
+ s.Triggers = sortSchemaObjectsByName(s.Triggers)
+
+ return s
+}
+
+// sortSchemaObjectsByName returns a (copied) sorted list of schema objects.
+func sortSchemaObjectsByName[S Object](vals []S) []S {
+ clonedVals := make([]S, len(vals))
+ copy(clonedVals, vals)
+ sort.Slice(clonedVals, func(i, j int) bool {
+ return clonedVals[i].GetName() < clonedVals[j].GetName()
+ })
+ return clonedVals
+}
+
+func (s Schema) Hash() (string, error) {
+ // alternatively, we can print the struct as a string and hash it
+ hashVal, err := hashstructure.Hash(s.Normalize(), hashstructure.FormatV2, nil)
+ if err != nil {
+ return "", fmt.Errorf("hashing schema: %w", err)
+ }
+ return fmt.Sprintf("%x", hashVal), nil
+}
+
+type Table struct {
+ Name string
+ Columns []Column
+ CheckConstraints []CheckConstraint
+
+ // PartitionKeyDef is the output of Pg function pg_get_partkeydef:
+ // PARTITION BY $PartitionKeyDef
+ // If empty, then the table is not partitioned
+ PartitionKeyDef string
+
+ ParentTableName string
+ ForValues string
+}
+
+func (t Table) IsPartitioned() bool {
+ return len(t.PartitionKeyDef) > 0
+}
+
+func (t Table) IsPartition() bool {
+ return len(t.ForValues) > 0
+}
+
+func (t Table) GetName() string {
+ return t.Name
+}
+
+type Column struct {
+ Name string
+ Type string
+ Collation SchemaQualifiedName
+ // If the column has a default value, this will be a SQL string representing that value.
+ // Examples:
+ // ''::text
+ // CURRENT_TIMESTAMP
+ // If empty, indicates that there is no default value.
+ Default string
+ IsNullable bool
+
+ // Size is the number of bytes required to store the value.
+ // It is used for data-packing purposes
+ Size int
+}
+
+func (c Column) GetName() string {
+ return c.Name
+}
+
+func (c Column) IsCollated() bool {
+ return !c.Collation.IsEmpty()
+}
+
+var (
+ // The first matching group is the "CREATE [UNIQUE] INDEX ". UNIQUE is an optional match
+ // because only UNIQUE indices will have the UNIQUE keyword in their pg_get_indexdef statement
+ //
+ // The third matching group is the rest of the statement
+ idxToConcurrentlyRegex = regexp.MustCompile("^(CREATE (UNIQUE )?INDEX )(.*)$")
+)
+
+// GetIndexDefStatement is the output of pg_get_indexdef. It is a `CREATE INDEX` statement that will re-create
+// the index. This statement does not contain `CONCURRENTLY`.
+// For unique indexes, it does contain `UNIQUE`
+// For partitioned tables, it does contain `ONLY`
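+// For example (illustrative), ToCreateIndexConcurrently rewrites
+// "CREATE UNIQUE INDEX foobar_idx ON public.foobar USING btree (id)" into
+// "CREATE UNIQUE INDEX CONCURRENTLY foobar_idx ON public.foobar USING btree (id)"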
+type GetIndexDefStatement string
+
+func (i GetIndexDefStatement) ToCreateIndexConcurrently() (string, error) {
+ if !idxToConcurrentlyRegex.MatchString(string(i)) {
+ return "", fmt.Errorf("%s follows an unexpected structure", i)
+ }
+ return idxToConcurrentlyRegex.ReplaceAllString(string(i), "${1}CONCURRENTLY ${3}"), nil
+}
+
+type Index struct {
+ TableName string
+ Name string
+ Columns []string
+ IsInvalid bool
+ IsPk bool
+ IsUnique bool
+ // ConstraintName is the name of the constraint associated with an index. Empty string if no associated constraint.
+ // Once we need support for constraints not associated with indexes, we'll add a
+ // Constraint schema object and starting fetching constraints directly
+ ConstraintName string
+
+ // GetIndexDefStmt is the output of pg_get_indexdef
+ GetIndexDefStmt GetIndexDefStatement
+
+ // ParentIdxName is the name of the parent index if the index is a partition of an index
+ ParentIdxName string
+}
+
+func (i Index) GetName() string {
+ return i.Name
+}
+
+func (i Index) IsPartitionOfIndex() bool {
+ return len(i.ParentIdxName) > 0
+}
+
+type CheckConstraint struct {
+ Name string
+ Expression string
+ IsValid bool
+ IsInheritable bool
+ DependsOnFunctions []SchemaQualifiedName
+}
+
+func (c CheckConstraint) GetName() string {
+ return c.Name
+}
+
+type Function struct {
+ SchemaQualifiedName
+ // FunctionDef is the statement required to completely (re)create
+ // the function, as returned by `pg_get_functiondef`. It is a CREATE OR REPLACE
+ // statement
+ FunctionDef string
+ // Language is the language of the function. This is relevant in determining if we
+ // can track the dependencies of the function (or not)
+ Language string
+ DependsOnFunctions []SchemaQualifiedName
+}
+
+var (
+ // The first matching group is the "CREATE ". The second matching group is the rest of the statement
+ triggerToOrReplaceRegex = regexp.MustCompile("^(CREATE )(.*)$")
+)
+
+// GetTriggerDefStatement is the output of pg_get_triggerdef. It is a `CREATE TRIGGER` statement that will create
+// the trigger. This statement does not contain `OR REPLACE`
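+// For example (illustrative), ToCreateOrReplace rewrites "CREATE TRIGGER some_trigger ..." into
+// "CREATE OR REPLACE TRIGGER some_trigger ..."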
+type GetTriggerDefStatement string
+
+func (g GetTriggerDefStatement) ToCreateOrReplace() (string, error) {
+ if !triggerToOrReplaceRegex.MatchString(string(g)) {
+ return "", fmt.Errorf("%s follows an unexpected structure", g)
+ }
+ return triggerToOrReplaceRegex.ReplaceAllString(string(g), "${1}OR REPLACE ${2}"), nil
+}
+
+type Trigger struct {
+ EscapedName string
+ OwningTable SchemaQualifiedName
+ // OwningTableUnescapedName lets us be backwards compatible with the TableSQLVertexGenerator, which
+ // currently uses the unescaped name as the vertex id. This will be removed once the TableSQLVertexGenerator
+ // is migrated to use SchemaQualifiedName
+ OwningTableUnescapedName string
+ Function SchemaQualifiedName
+ // GetTriggerDefStmt is the statement required to completely (re)create the trigger, as returned
+ // by pg_get_triggerdef
+ GetTriggerDefStmt GetTriggerDefStatement
+}
+
+func (t Trigger) GetName() string {
+ return t.OwningTable.GetFQEscapedName() + "_" + t.EscapedName
+}
+
+// GetPublicSchema fetches the "public" schema. It is a non-atomic operation
+func GetPublicSchema(ctx context.Context, db queries.DBTX) (Schema, error) {
+ q := queries.New(db)
+
+ tables, err := fetchTables(ctx, q)
+ if err != nil {
+ return Schema{}, fmt.Errorf("fetchTables: %w", err)
+ }
+
+ indexes, err := fetchIndexes(ctx, q)
+ if err != nil {
+ return Schema{}, fmt.Errorf("fetchIndexes: %w", err)
+ }
+
+ functions, err := fetchFunctions(ctx, q)
+ if err != nil {
+ return Schema{}, fmt.Errorf("fetchFunctions: %w", err)
+ }
+
+ triggers, err := fetchTriggers(ctx, q)
+ if err != nil {
+ return Schema{}, fmt.Errorf("fetchTriggers: %w", err)
+ }
+
+ return Schema{
+ Name: "public",
+ Tables: tables,
+ Indexes: indexes,
+ Functions: functions,
+ Triggers: triggers,
+ }, nil
+}
+
+func fetchTables(ctx context.Context, q *queries.Queries) ([]Table, error) {
+ rawTables, err := q.GetTables(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("GetTables(): %w", err)
+ }
+
+ tablesToCheckConsMap, err := fetchCheckConsAndBuildTableToCheckConsMap(ctx, q)
+ if err != nil {
+ return nil, fmt.Errorf("fetchCheckConsAndBuildTableToCheckConsMap: %w", err)
+ }
+
+ var tables []Table
+ for _, table := range rawTables {
+ if len(table.ParentTableName) > 0 && table.ParentTableSchemaName != "public" {
+ return nil, fmt.Errorf(
+ "table %s has a parent table in schema %s; only parent tables in the public schema are supported",
+ table.TableName,
+ table.ParentTableSchemaName,
+ )
+ }
+
+ rawColumns, err := q.GetColumnsForTable(ctx, table.Oid)
+ if err != nil {
+ return nil, fmt.Errorf("GetColumnsForTable(%s): %w", table.Oid, err)
+ }
+ var columns []Column
+ for _, column := range rawColumns {
+ collation := SchemaQualifiedName{}
+ if len(column.CollationName) > 0 {
+ collation = SchemaQualifiedName{
+ EscapedName: EscapeIdentifier(column.CollationName),
+ SchemaName: column.CollationSchemaName,
+ }
+ }
+
+ columns = append(columns, Column{
+ Name: column.ColumnName,
+ Type: column.ColumnType,
+ Collation: collation,
+ IsNullable: !column.IsNotNull,
+ // If the column has a default value, this will be a SQL string representing that value.
+ // Examples:
+ // ''::text
+ // CURRENT_TIMESTAMP
+ // If empty, indicates that there is no default value.
+ Default: column.DefaultValue,
+ Size: int(column.ColumnSize),
+ })
+ }
+
+ tables = append(tables, Table{
+ Name: table.TableName,
+ Columns: columns,
+ CheckConstraints: tablesToCheckConsMap[table.TableName],
+
+ PartitionKeyDef: table.PartitionKeyDef,
+
+ ParentTableName: table.ParentTableName,
+ ForValues: table.PartitionForValues,
+ })
+ }
+ return tables, nil
+}
+
+// fetchCheckConsAndBuildTableToCheckConsMap fetches the check constraints and builds a map of table name to the check
+// constraints within the table
+func fetchCheckConsAndBuildTableToCheckConsMap(ctx context.Context, q *queries.Queries) (map[string][]CheckConstraint, error) {
+ rawCheckCons, err := q.GetCheckConstraints(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("GetCheckConstraints: %w", err)
+ }
+
+ result := make(map[string][]CheckConstraint)
+ for _, cc := range rawCheckCons {
+ dependsOnFunctions, err := fetchDependsOnFunctions(ctx, q, cc.Oid)
+ if err != nil {
+ return nil, fmt.Errorf("fetchDependsOnFunctions(%s): %w", cc.Oid, err)
+ }
+
+ checkCon := CheckConstraint{
+ Name: cc.Name,
+ Expression: cc.Expression,
+ IsValid: cc.IsValid,
+ IsInheritable: !cc.IsNotInheritable,
+ DependsOnFunctions: dependsOnFunctions,
+ }
+ result[cc.TableName] = append(result[cc.TableName], checkCon)
+ }
+
+ return result, nil
+}
+
+// fetchIndexes fetches the indexes. We fetch all indexes at once to minimize the number of queries, since the
+// columns must then be fetched for each index
+func fetchIndexes(ctx context.Context, q *queries.Queries) ([]Index, error) {
+ rawIndexes, err := q.GetIndexes(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("GetColumnsInPublicSchema: %w", err)
+ }
+
+ var indexes []Index
+ for _, rawIndex := range rawIndexes {
+ rawColumns, err := q.GetColumnsForIndex(ctx, rawIndex.Oid)
+ if err != nil {
+ return nil, fmt.Errorf("GetColumnsForIndex(%s): %w", rawIndex.Oid, err)
+ }
+
+ indexes = append(indexes, Index{
+ TableName: rawIndex.TableName,
+ Name: rawIndex.IndexName,
+ Columns: rawColumns,
+ GetIndexDefStmt: GetIndexDefStatement(rawIndex.DefStmt),
+ IsInvalid: !rawIndex.IndexIsValid,
+ IsPk: rawIndex.IndexIsPk,
+ IsUnique: rawIndex.IndexIsUnique,
+ ConstraintName: rawIndex.ConstraintName,
+ ParentIdxName: rawIndex.ParentIndexName,
+ })
+ }
+
+ return indexes, nil
+}
+
+// fetchFunctions fetches the functions, including the functions that each of them depends on
+func fetchFunctions(ctx context.Context, q *queries.Queries) ([]Function, error) {
+ rawFunctions, err := q.GetFunctions(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("GetFunctions: %w", err)
+ }
+
+ var functions []Function
+ for _, rawFunction := range rawFunctions {
+ dependsOnFunctions, err := fetchDependsOnFunctions(ctx, q, rawFunction.Oid)
+ if err != nil {
+ return nil, fmt.Errorf("fetchDependsOnFunctions(%s): %w", rawFunction.Oid, err)
+ }
+
+ functions = append(functions, Function{
+ SchemaQualifiedName: buildFuncName(rawFunction.FuncName, rawFunction.FuncIdentityArguments, rawFunction.FuncSchemaName),
+ FunctionDef: rawFunction.FuncDef,
+ Language: rawFunction.FuncLang,
+ DependsOnFunctions: dependsOnFunctions,
+ })
+ }
+
+ return functions, nil
+}
+
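+// fetchDependsOnFunctions fetches the schema-qualified names of the functions that the object with
+// the given oid depends on (e.g., functions referenced by a check constraint or by another function)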
+func fetchDependsOnFunctions(ctx context.Context, q *queries.Queries, oid any) ([]SchemaQualifiedName, error) {
+ dependsOnFunctions, err := q.GetDependsOnFunctions(ctx, oid)
+ if err != nil {
+ return nil, err
+ }
+
+ var functionNames []SchemaQualifiedName
+ for _, rawFunction := range dependsOnFunctions {
+ functionNames = append(functionNames, buildFuncName(rawFunction.FuncName, rawFunction.FuncIdentityArguments, rawFunction.FuncSchemaName))
+ }
+
+ return functionNames, nil
+}
+
+func fetchTriggers(ctx context.Context, q *queries.Queries) ([]Trigger, error) {
+ rawTriggers, err := q.GetTriggers(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("GetTriggers: %w", err)
+ }
+
+ var triggers []Trigger
+ for _, rawTrigger := range rawTriggers {
+ triggers = append(triggers, Trigger{
+ EscapedName: EscapeIdentifier(rawTrigger.TriggerName),
+ OwningTable: buildNameFromUnescaped(rawTrigger.OwningTableName, rawTrigger.OwningTableSchemaName),
+ OwningTableUnescapedName: rawTrigger.OwningTableName,
+ Function: buildFuncName(rawTrigger.FuncName, rawTrigger.FuncIdentityArguments, rawTrigger.FuncSchemaName),
+ GetTriggerDefStmt: GetTriggerDefStatement(rawTrigger.TriggerDef),
+ })
+ }
+
+ return triggers, nil
+}
+
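+// buildFuncName builds the schema-qualified name for a function from its name, identity arguments
+// (as reported by Postgres), and schema, e.g. ("increment", "i integer", "public") yields
+// the escaped name "increment"(i integer) in schema public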
+func buildFuncName(name, identityArguments, schemaName string) SchemaQualifiedName {
+ return SchemaQualifiedName{
+ SchemaName: schemaName,
+ EscapedName: fmt.Sprintf("\"%s\"(%s)", name, identityArguments),
+ }
+}
+
+func buildNameFromUnescaped(unescapedName, schemaName string) SchemaQualifiedName {
+ return SchemaQualifiedName{
+ EscapedName: EscapeIdentifier(unescapedName),
+ SchemaName: schemaName,
+ }
+}
+
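+// EscapeIdentifier wraps an identifier in double quotes so it can be interpolated into generated
+// SQL, e.g. foo -> "foo". Note that it does not escape double quotes embedded in the name itself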
+func EscapeIdentifier(name string) string {
+ return fmt.Sprintf("\"%s\"", name)
+}
diff --git a/internal/schema/schema_test.go b/internal/schema/schema_test.go
new file mode 100644
index 0000000..d025cad
--- /dev/null
+++ b/internal/schema/schema_test.go
@@ -0,0 +1,886 @@
+package schema_test
+
+import (
+ "context"
+ "database/sql"
+ "testing"
+
+ _ "github.com/jackc/pgx/v4/stdlib"
+ "github.com/kr/pretty"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "github.com/stripe/pg-schema-diff/internal/pgengine"
+ "github.com/stripe/pg-schema-diff/internal/schema"
+)
+
+type testCase struct {
+ name string
+ ddl []string
+ expectedSchema schema.Schema
+ expectedHash string
+ expectedErrIs error
+}
+
+var (
+ defaultCollation = schema.SchemaQualifiedName{
+ EscapedName: `"default"`,
+ SchemaName: "pg_catalog",
+ }
+ cCollation = schema.SchemaQualifiedName{
+ EscapedName: `"C"`,
+ SchemaName: "pg_catalog",
+ }
+
+ testCases = []*testCase{
+ {
+ name: "Simple test",
+ ddl: []string{`
+ CREATE FUNCTION add(a integer, b integer) RETURNS integer
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN a + b;
+
+ CREATE FUNCTION increment(i integer) RETURNS integer AS $$
+ BEGIN
+ RETURN i + 1;
+ END;
+ $$ LANGUAGE plpgsql;
+
+ CREATE FUNCTION function_with_dependencies(a integer, b integer) RETURNS integer
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN add(a, b) + increment(a);
+
+ CREATE TABLE foo (
+ id INTEGER PRIMARY KEY,
+ author TEXT COLLATE "C",
+ content TEXT NOT NULL DEFAULT '',
+ created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP CHECK (created_at > CURRENT_TIMESTAMP - interval '1 month') NO INHERIT,
+ version INT NOT NULL DEFAULT 0,
+ CHECK ( function_with_dependencies(id, id) > 0)
+ );
+
+ ALTER TABLE foo ADD CONSTRAINT author_check CHECK (author IS NOT NULL AND LENGTH(author) > 0) NO INHERIT NOT VALID;
+ CREATE INDEX some_idx ON foo USING hash (content);
+ CREATE UNIQUE INDEX some_unique_idx ON foo (created_at DESC, author ASC);
+
+ CREATE FUNCTION increment_version() RETURNS TRIGGER AS $$
+ BEGIN
+ NEW.version = OLD.version + 1;
+ RETURN NEW;
+ END;
+ $$ language 'plpgsql';
+
+ CREATE TRIGGER some_trigger
+ BEFORE UPDATE ON foo
+ FOR EACH ROW
+ WHEN (OLD.* IS DISTINCT FROM NEW.*)
+ EXECUTE PROCEDURE increment_version();
+ `},
+ expectedHash: "9648c294aed76ef6",
+ expectedSchema: schema.Schema{
+ Name: "public",
+ Tables: []schema.Table{
+ {
+ Name: "foo",
+ Columns: []schema.Column{
+ {Name: "id", Type: "integer", Size: 4},
+ {Name: "author", Type: "text", IsNullable: true, Size: -1, Collation: cCollation},
+ {Name: "content", Type: "text", Default: "''::text", Size: -1, Collation: defaultCollation},
+ {Name: "created_at", Type: "timestamp without time zone", Default: "CURRENT_TIMESTAMP", Size: 8},
+ {Name: "version", Type: "integer", Default: "0", Size: 4},
+ },
+ CheckConstraints: []schema.CheckConstraint{
+ {Name: "author_check", Expression: "((author IS NOT NULL) AND (length(author) > 0))"},
+ {Name: "foo_created_at_check", Expression: "(created_at > (CURRENT_TIMESTAMP - '1 mon'::interval))", IsValid: true},
+ {
+ Name: "foo_id_check",
+ Expression: "(function_with_dependencies(id, id) > 0)",
+ IsValid: true,
+ IsInheritable: true,
+ DependsOnFunctions: []schema.SchemaQualifiedName{
+ {EscapedName: "\"function_with_dependencies\"(a integer, b integer)", SchemaName: "public"},
+ },
+ },
+ },
+ },
+ },
+ Indexes: []schema.Index{
+ {
+ TableName: "foo",
+ Name: "foo_pkey", Columns: []string{"id"}, IsPk: true, IsUnique: true, ConstraintName: "foo_pkey",
+ GetIndexDefStmt: "CREATE UNIQUE INDEX foo_pkey ON public.foo USING btree (id)",
+ },
+ {
+ TableName: "foo",
+ Name: "some_idx", Columns: []string{"content"},
+ GetIndexDefStmt: "CREATE INDEX some_idx ON public.foo USING hash (content)",
+ },
+ {
+ TableName: "foo",
+ Name: "some_unique_idx", Columns: []string{"created_at", "author"}, IsPk: false, IsUnique: true,
+ GetIndexDefStmt: "CREATE UNIQUE INDEX some_unique_idx ON public.foo USING btree (created_at DESC, author)",
+ },
+ },
+ Functions: []schema.Function{
+ {
+ SchemaQualifiedName: schema.SchemaQualifiedName{EscapedName: "\"add\"(a integer, b integer)", SchemaName: "public"},
+ FunctionDef: "CREATE OR REPLACE FUNCTION public.add(a integer, b integer)\n RETURNS integer\n LANGUAGE sql\n IMMUTABLE STRICT\nRETURN (a + b)\n",
+ Language: "sql",
+ },
+ {
+ SchemaQualifiedName: schema.SchemaQualifiedName{EscapedName: "\"function_with_dependencies\"(a integer, b integer)", SchemaName: "public"},
+ FunctionDef: "CREATE OR REPLACE FUNCTION public.function_with_dependencies(a integer, b integer)\n RETURNS integer\n LANGUAGE sql\n IMMUTABLE STRICT\nRETURN (add(a, b) + increment(a))\n",
+ Language: "sql",
+ DependsOnFunctions: []schema.SchemaQualifiedName{
+ {EscapedName: "\"add\"(a integer, b integer)", SchemaName: "public"},
+ {EscapedName: "\"increment\"(i integer)", SchemaName: "public"},
+ },
+ },
+ {
+ SchemaQualifiedName: schema.SchemaQualifiedName{EscapedName: "\"increment\"(i integer)", SchemaName: "public"},
+ FunctionDef: "CREATE OR REPLACE FUNCTION public.increment(i integer)\n RETURNS integer\n LANGUAGE plpgsql\nAS $function$\n\t\t\t\t\tBEGIN\n\t\t\t\t\t\t\tRETURN i + 1;\n\t\t\t\t\tEND;\n\t\t\t$function$\n",
+ Language: "plpgsql",
+ },
+ {
+ SchemaQualifiedName: schema.SchemaQualifiedName{EscapedName: "\"increment_version\"()", SchemaName: "public"},
+ FunctionDef: "CREATE OR REPLACE FUNCTION public.increment_version()\n RETURNS trigger\n LANGUAGE plpgsql\nAS $function$\n\t\t\t\tBEGIN\n\t\t\t\t\tNEW.version = OLD.version + 1;\n\t\t\t\t\tRETURN NEW;\n\t\t\t\tEND;\n\t\t\t$function$\n",
+ Language: "plpgsql",
+ },
+ },
+ Triggers: []schema.Trigger{
+ {
+ EscapedName: "\"some_trigger\"",
+ OwningTable: schema.SchemaQualifiedName{EscapedName: "\"foo\"", SchemaName: "public"},
+ OwningTableUnescapedName: "foo",
+ Function: schema.SchemaQualifiedName{EscapedName: "\"increment_version\"()", SchemaName: "public"},
+ GetTriggerDefStmt: "CREATE TRIGGER some_trigger BEFORE UPDATE ON public.foo FOR EACH ROW WHEN ((old.* IS DISTINCT FROM new.*)) EXECUTE FUNCTION increment_version()",
+ },
+ },
+ },
+ },
+ {
+ name: "Simple partition test",
+ ddl: []string{`
+ CREATE TABLE foo (
+ id INTEGER CHECK (id > 0),
+ author TEXT COLLATE "C",
+ content TEXT DEFAULT '',
+ genre VARCHAR(256) NOT NULL,
+ created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP CHECK (created_at > CURRENT_TIMESTAMP - interval '1 month'),
+ PRIMARY KEY (author, id)
+ ) PARTITION BY LIST (author);
+ ALTER TABLE foo ADD CONSTRAINT author_check CHECK (author IS NOT NULL AND LENGTH(author) > 0) NOT VALID;
+
+ CREATE FUNCTION increment_version() RETURNS TRIGGER AS $$
+ BEGIN
+ NEW.version = OLD.version + 1;
+ RETURN NEW;
+ END;
+ $$ language 'plpgsql';
+
+ CREATE TRIGGER some_trigger
+ BEFORE UPDATE ON foo
+ FOR EACH ROW
+ WHEN (OLD.* IS DISTINCT FROM NEW.*)
+ EXECUTE PROCEDURE increment_version();
+
+ CREATE TABLE foo_1 PARTITION OF foo(
+ content NOT NULL
+ ) FOR VALUES IN ('some author 1');
+ CREATE TABLE foo_2 PARTITION OF foo FOR VALUES IN ('some author 2');
+ CREATE TABLE foo_3 PARTITION OF foo FOR VALUES IN ('some author 3');
+
+ -- partitioned indexes
+ CREATE INDEX some_partitioned_idx ON foo USING hash(author);
+ CREATE UNIQUE INDEX some_unique_partitioned_idx ON foo(author, created_at DESC);
+ CREATE INDEX some_invalid_idx ON ONLY foo(author, genre);
+
+ -- local indexes
+ CREATE UNIQUE INDEX foo_1_local_idx ON foo_1(author DESC, id);
+ CREATE UNIQUE INDEX foo_2_local_idx ON foo_2(author, content);
+ CREATE UNIQUE INDEX foo_3_local_idx ON foo_3(author, created_at);
+
+ CREATE TRIGGER some_partition_trigger
+ BEFORE UPDATE ON foo_1
+ FOR EACH ROW
+ WHEN (OLD.* IS DISTINCT FROM NEW.*)
+ EXECUTE PROCEDURE increment_version();
+
+ `},
+ expectedHash: "651c9229cd8373f0",
+ expectedSchema: schema.Schema{
+ Name: "public",
+ Tables: []schema.Table{
+ {
+ Name: "foo",
+ Columns: []schema.Column{
+ {Name: "id", Type: "integer", Size: 4},
+ {Name: "author", Type: "text", Size: -1, Collation: cCollation},
+ {Name: "content", Type: "text", Default: "''::text", IsNullable: true, Size: -1, Collation: defaultCollation},
+ {Name: "genre", Type: "character varying(256)", Size: -1, Collation: defaultCollation},
+ {Name: "created_at", Type: "timestamp without time zone", Default: "CURRENT_TIMESTAMP", Size: 8},
+ },
+ CheckConstraints: []schema.CheckConstraint{
+ {Name: "author_check", Expression: "((author IS NOT NULL) AND (length(author) > 0))", IsInheritable: true},
+ {Name: "foo_created_at_check", Expression: "(created_at > (CURRENT_TIMESTAMP - '1 mon'::interval))", IsValid: true, IsInheritable: true},
+ {Name: "foo_id_check", Expression: "(id > 0)", IsValid: true, IsInheritable: true},
+ },
+ PartitionKeyDef: "LIST (author)",
+ },
+ {
+ ParentTableName: "foo",
+ Name: "foo_1",
+ Columns: []schema.Column{
+ {Name: "id", Type: "integer", Size: 4},
+ {Name: "author", Type: "text", Size: -1, Collation: cCollation},
+ {Name: "content", Type: "text", Default: "''::text", Size: -1, Collation: defaultCollation},
+ {Name: "genre", Type: "character varying(256)", Size: -1, Collation: defaultCollation},
+ {Name: "created_at", Type: "timestamp without time zone", Default: "CURRENT_TIMESTAMP", Size: 8},
+ },
+ CheckConstraints: nil,
+ ForValues: "FOR VALUES IN ('some author 1')",
+ },
+ {
+ ParentTableName: "foo",
+ Name: "foo_2",
+ Columns: []schema.Column{
+ {Name: "id", Type: "integer", Size: 4},
+ {Name: "author", Type: "text", Size: -1, Collation: cCollation},
+ {Name: "content", Type: "text", Default: "''::text", IsNullable: true, Size: -1, Collation: defaultCollation},
+ {Name: "genre", Type: "character varying(256)", Size: -1, Collation: defaultCollation},
+ {Name: "created_at", Type: "timestamp without time zone", Default: "CURRENT_TIMESTAMP", Size: 8},
+ },
+ CheckConstraints: nil,
+ ForValues: "FOR VALUES IN ('some author 2')",
+ },
+ {
+ ParentTableName: "foo",
+ Name: "foo_3",
+ Columns: []schema.Column{
+ {Name: "id", Type: "integer", Size: 4},
+ {Name: "author", Type: "text", Size: -1, Collation: cCollation},
+ {Name: "content", Type: "text", Default: "''::text", IsNullable: true, Size: -1, Collation: defaultCollation},
+ {Name: "genre", Type: "character varying(256)", Size: -1, Collation: defaultCollation},
+ {Name: "created_at", Type: "timestamp without time zone", Default: "CURRENT_TIMESTAMP", Size: 8},
+ },
+ CheckConstraints: nil,
+ ForValues: "FOR VALUES IN ('some author 3')",
+ },
+ },
+ Indexes: []schema.Index{
+ {
+ TableName: "foo",
+ Name: "foo_pkey", Columns: []string{"author", "id"}, IsPk: true, IsUnique: true, ConstraintName: "foo_pkey",
+ GetIndexDefStmt: "CREATE UNIQUE INDEX foo_pkey ON ONLY public.foo USING btree (author, id)",
+ },
+ {
+ TableName: "foo",
+ Name: "some_partitioned_idx", Columns: []string{"author"},
+ GetIndexDefStmt: "CREATE INDEX some_partitioned_idx ON ONLY public.foo USING hash (author)",
+ },
+ {
+ TableName: "foo",
+ Name: "some_unique_partitioned_idx", Columns: []string{"author", "created_at"}, IsPk: false, IsUnique: true,
+ GetIndexDefStmt: "CREATE UNIQUE INDEX some_unique_partitioned_idx ON ONLY public.foo USING btree (author, created_at DESC)",
+ },
+ {
+ TableName: "foo",
+ Name: "some_invalid_idx", Columns: []string{"author", "genre"}, IsInvalid: true, IsPk: false, IsUnique: false,
+ GetIndexDefStmt: "CREATE INDEX some_invalid_idx ON ONLY public.foo USING btree (author, genre)",
+ },
+ // foo_1 indexes
+ {
+ TableName: "foo_1",
+ Name: "foo_1_author_idx", Columns: []string{"author"}, ParentIdxName: "some_partitioned_idx",
+ GetIndexDefStmt: "CREATE INDEX foo_1_author_idx ON public.foo_1 USING hash (author)",
+ },
+ {
+ TableName: "foo_1",
+ Name: "foo_1_author_created_at_idx", Columns: []string{"author", "created_at"}, IsPk: false, IsUnique: true, ParentIdxName: "some_unique_partitioned_idx",
+ GetIndexDefStmt: "CREATE UNIQUE INDEX foo_1_author_created_at_idx ON public.foo_1 USING btree (author, created_at DESC)",
+ },
+ {
+ TableName: "foo_1",
+ Name: "foo_1_local_idx", Columns: []string{"author", "id"}, IsUnique: true,
+ GetIndexDefStmt: "CREATE UNIQUE INDEX foo_1_local_idx ON public.foo_1 USING btree (author DESC, id)",
+ },
+ {
+ TableName: "foo_1",
+ Name: "foo_1_pkey", Columns: []string{"author", "id"}, IsPk: true, IsUnique: true, ConstraintName: "foo_1_pkey", ParentIdxName: "foo_pkey",
+ GetIndexDefStmt: "CREATE UNIQUE INDEX foo_1_pkey ON public.foo_1 USING btree (author, id)",
+ },
+ // foo_2 indexes
+ {
+ TableName: "foo_2",
+ Name: "foo_2_author_idx", Columns: []string{"author"}, ParentIdxName: "some_partitioned_idx",
+ GetIndexDefStmt: "CREATE INDEX foo_2_author_idx ON public.foo_2 USING hash (author)",
+ },
+ {
+ TableName: "foo_2",
+ Name: "foo_2_author_created_at_idx", Columns: []string{"author", "created_at"}, IsPk: false, IsUnique: true, ParentIdxName: "some_unique_partitioned_idx",
+ GetIndexDefStmt: "CREATE UNIQUE INDEX foo_2_author_created_at_idx ON public.foo_2 USING btree (author, created_at DESC)",
+ },
+ {
+ TableName: "foo_2",
+ Name: "foo_2_local_idx", Columns: []string{"author", "content"}, IsUnique: true,
+ GetIndexDefStmt: "CREATE UNIQUE INDEX foo_2_local_idx ON public.foo_2 USING btree (author, content)",
+ },
+ {
+ TableName: "foo_2",
+ Name: "foo_2_pkey", Columns: []string{"author", "id"}, IsPk: true, IsUnique: true, ConstraintName: "foo_2_pkey", ParentIdxName: "foo_pkey",
+ GetIndexDefStmt: "CREATE UNIQUE INDEX foo_2_pkey ON public.foo_2 USING btree (author, id)",
+ },
+ // foo_3 indexes
+ {
+ TableName: "foo_3",
+ Name: "foo_3_author_idx", Columns: []string{"author"}, ParentIdxName: "some_partitioned_idx",
+ GetIndexDefStmt: "CREATE INDEX foo_3_author_idx ON public.foo_3 USING hash (author)",
+ },
+ {
+ TableName: "foo_3",
+ Name: "foo_3_author_created_at_idx", Columns: []string{"author", "created_at"}, IsPk: false, IsUnique: true, ParentIdxName: "some_unique_partitioned_idx",
+ GetIndexDefStmt: "CREATE UNIQUE INDEX foo_3_author_created_at_idx ON public.foo_3 USING btree (author, created_at DESC)",
+ },
+ {
+ TableName: "foo_3",
+ Name: "foo_3_local_idx", Columns: []string{"author", "created_at"}, IsUnique: true,
+ GetIndexDefStmt: "CREATE UNIQUE INDEX foo_3_local_idx ON public.foo_3 USING btree (author, created_at)",
+ },
+ {
+ TableName: "foo_3",
+ Name: "foo_3_pkey", Columns: []string{"author", "id"}, IsPk: true, IsUnique: true, ConstraintName: "foo_3_pkey", ParentIdxName: "foo_pkey",
+ GetIndexDefStmt: "CREATE UNIQUE INDEX foo_3_pkey ON public.foo_3 USING btree (author, id)",
+ },
+ },
+
+ Functions: []schema.Function{
+ {
+ SchemaQualifiedName: schema.SchemaQualifiedName{EscapedName: "\"increment_version\"()", SchemaName: "public"},
+ FunctionDef: "CREATE OR REPLACE FUNCTION public.increment_version()\n RETURNS trigger\n LANGUAGE plpgsql\nAS $function$\n\t\t\t\tBEGIN\n\t\t\t\t\tNEW.version = OLD.version + 1;\n\t\t\t\t\tRETURN NEW;\n\t\t\t\tEND;\n\t\t\t$function$\n",
+ Language: "plpgsql",
+ },
+ },
+ Triggers: []schema.Trigger{
+ {
+ EscapedName: "\"some_trigger\"",
+ OwningTable: schema.SchemaQualifiedName{EscapedName: "\"foo\"", SchemaName: "public"},
+ OwningTableUnescapedName: "foo",
+ Function: schema.SchemaQualifiedName{EscapedName: "\"increment_version\"()", SchemaName: "public"},
+ GetTriggerDefStmt: "CREATE TRIGGER some_trigger BEFORE UPDATE ON public.foo FOR EACH ROW WHEN ((old.* IS DISTINCT FROM new.*)) EXECUTE FUNCTION increment_version()",
+ },
+ {
+ EscapedName: "\"some_partition_trigger\"",
+ OwningTable: schema.SchemaQualifiedName{EscapedName: "\"foo_1\"", SchemaName: "public"},
+ OwningTableUnescapedName: "foo_1",
+ Function: schema.SchemaQualifiedName{EscapedName: "\"increment_version\"()", SchemaName: "public"},
+ GetTriggerDefStmt: "CREATE TRIGGER some_partition_trigger BEFORE UPDATE ON public.foo_1 FOR EACH ROW WHEN ((old.* IS DISTINCT FROM new.*)) EXECUTE FUNCTION increment_version()",
+ },
+ },
+ },
+ },
+ {
+ name: "Partition test local primary key",
+ ddl: []string{`
+ CREATE TABLE foo (
+ id INTEGER,
+ author TEXT
+ ) PARTITION BY LIST (author);
+
+ CREATE TABLE foo_1 PARTITION OF foo(
+ PRIMARY KEY (author, id)
+ ) FOR VALUES IN ('some author 1');
+ `},
+ expectedHash: "6976f3d0ada49b66",
+ expectedSchema: schema.Schema{
+ Name: "public",
+ Tables: []schema.Table{
+ {
+ Name: "foo",
+ Columns: []schema.Column{
+ {Name: "id", Type: "integer", IsNullable: true, Size: 4},
+ {Name: "author", Type: "text", IsNullable: true, Size: -1, Collation: defaultCollation},
+ },
+ CheckConstraints: nil,
+ PartitionKeyDef: "LIST (author)",
+ },
+ {
+ ParentTableName: "foo",
+ Name: "foo_1",
+ Columns: []schema.Column{
+ {Name: "id", Type: "integer", Size: 4},
+ {Name: "author", Type: "text", Size: -1, Collation: defaultCollation},
+ },
+ CheckConstraints: nil,
+ ForValues: "FOR VALUES IN ('some author 1')",
+ },
+ },
+ Indexes: []schema.Index{
+ {
+ TableName: "foo_1",
+ Name: "foo_1_pkey", Columns: []string{"author", "id"}, IsPk: true, IsUnique: true, ConstraintName: "foo_1_pkey",
+ GetIndexDefStmt: "CREATE UNIQUE INDEX foo_1_pkey ON public.foo_1 USING btree (author, id)",
+ },
+ },
+ },
+ },
+ {
+ name: "Common Data Types",
+ ddl: []string{`
+ CREATE TABLE foo (
+ "varchar" VARCHAR(128) NOT NULL DEFAULT '',
+ "text" TEXT NOT NULL DEFAULT '',
+ "bool" BOOLEAN NOT NULL DEFAULT False,
+ "blob" BYTEA NOT NULL DEFAULT '',
+ "smallint" SMALLINT NOT NULL DEFAULT 0,
+ "real" REAL NOT NULL DEFAULT 0.0,
+ "double_precision" DOUBLE PRECISION NOT NULL DEFAULT 0.0,
+ "integer" INTEGER NOT NULL DEFAULT 0,
+ "big_integer" BIGINT NOT NULL DEFAULT 0,
+ "decimal" DECIMAL(65, 10) NOT NULL DEFAULT 0.0
+ );
+ `},
+ expectedHash: "2181b2da75bb74f7",
+ expectedSchema: schema.Schema{
+ Name: "public",
+ Tables: []schema.Table{
+ {
+ Name: "foo",
+ Columns: []schema.Column{
+ {Name: "varchar", Type: "character varying(128)", Default: "''::character varying", Size: -1, Collation: defaultCollation},
+ {Name: "text", Type: "text", Default: "''::text", Size: -1, Collation: defaultCollation},
+ {Name: "bool", Type: "boolean", Default: "false", Size: 1},
+ {Name: "blob", Type: "bytea", Default: `'\x'::bytea`, Size: -1},
+ {Name: "smallint", Type: "smallint", Default: "0", Size: 2},
+ {Name: "real", Type: "real", Default: "0.0", Size: 4},
+ {Name: "double_precision", Type: "double precision", Default: "0.0", Size: 8},
+ {Name: "integer", Type: "integer", Default: "0", Size: 4},
+ {Name: "big_integer", Type: "bigint", Default: "0", Size: 8},
+ {Name: "decimal", Type: "numeric(65,10)", Default: "0.0", Size: -1},
+ },
+ CheckConstraints: nil,
+ },
+ },
+ },
+ },
+ {
+ name: "Multi-Table",
+ ddl: []string{`
+ CREATE TABLE foo (
+ id INTEGER PRIMARY KEY CHECK (id > 0) NO INHERIT,
+ content TEXT DEFAULT 'some default'
+ );
+ CREATE INDEX foo_idx ON foo(id, content);
+ CREATE TABLE bar(
+ id INTEGER PRIMARY KEY CHECK (id > 0),
+ content TEXT NOT NULL
+ );
+ CREATE INDEX bar_idx ON bar(content, id);
+ CREATE TABLE foobar(
+ id INTEGER PRIMARY KEY,
+ content BIGINT NOT NULL
+ );
+ ALTER TABLE foobar ADD CONSTRAINT foobar_id_check CHECK (id > 0) NOT VALID;
+ CREATE UNIQUE INDEX foobar_idx ON foobar(content);
+ `},
+ expectedHash: "6518bbfe220d4f16",
+ expectedSchema: schema.Schema{
+ Name: "public",
+ Tables: []schema.Table{
+ {
+ Name: "foo",
+ Columns: []schema.Column{
+ {Name: "id", Type: "integer", Size: 4},
+ {Name: "content", Type: "text", IsNullable: true, Default: "'some default'::text", Size: -1, Collation: defaultCollation},
+ },
+ CheckConstraints: []schema.CheckConstraint{
+ {Name: "foo_id_check", Expression: "(id > 0)", IsValid: true},
+ },
+ },
+ {
+ Name: "bar",
+ Columns: []schema.Column{
+ {Name: "id", Type: "integer", Size: 4},
+ {Name: "content", Type: "text", Size: -1, Collation: defaultCollation},
+ },
+ CheckConstraints: []schema.CheckConstraint{
+ {Name: "bar_id_check", Expression: "(id > 0)", IsValid: true, IsInheritable: true},
+ },
+ },
+ {
+ Name: "foobar",
+ Columns: []schema.Column{
+ {Name: "id", Type: "integer", Size: 4},
+ {Name: "content", Type: "bigint", Size: 8},
+ },
+ CheckConstraints: []schema.CheckConstraint{
+ {Name: "foobar_id_check", Expression: "(id > 0)", IsInheritable: true},
+ },
+ },
+ },
+ Indexes: []schema.Index{
+ // foo indexes
+ {
+ TableName: "foo",
+ Name: "foo_pkey", Columns: []string{"id"}, IsPk: true, IsUnique: true, ConstraintName: "foo_pkey",
+ GetIndexDefStmt: "CREATE UNIQUE INDEX foo_pkey ON public.foo USING btree (id)",
+ },
+ {
+ TableName: "foo",
+ Name: "foo_idx", Columns: []string{"id", "content"},
+ GetIndexDefStmt: "CREATE INDEX foo_idx ON public.foo USING btree (id, content)",
+ },
+ // bar indexes
+ {
+ TableName: "bar",
+ Name: "bar_pkey", Columns: []string{"id"}, IsPk: true, IsUnique: true, ConstraintName: "bar_pkey",
+ GetIndexDefStmt: "CREATE UNIQUE INDEX bar_pkey ON public.bar USING btree (id)",
+ },
+ {
+ TableName: "bar",
+ Name: "bar_idx", Columns: []string{"content", "id"},
+ GetIndexDefStmt: "CREATE INDEX bar_idx ON public.bar USING btree (content, id)",
+ },
+ // foobar indexes
+ {
+ TableName: "foobar",
+ Name: "foobar_pkey", Columns: []string{"id"}, IsPk: true, IsUnique: true, ConstraintName: "foobar_pkey",
+ GetIndexDefStmt: "CREATE UNIQUE INDEX foobar_pkey ON public.foobar USING btree (id)",
+ },
+ {
+ TableName: "foobar",
+ Name: "foobar_idx", Columns: []string{"content"}, IsUnique: true,
+ GetIndexDefStmt: "CREATE UNIQUE INDEX foobar_idx ON public.foobar USING btree (content)",
+ },
+ },
+ },
+ },
+ {
+ name: "Multi-Schema",
+ ddl: []string{`
+ CREATE TABLE foo (
+ id INTEGER PRIMARY KEY CHECK (id > 0),
+ version INTEGER NOT NULL DEFAULT 0,
+ content TEXT
+ );
+
+ CREATE FUNCTION dup(in int, out f1 int, out f2 text)
+ AS $$ SELECT $1, CAST($1 AS text) || ' is text' $$
+ LANGUAGE SQL;
+
+
+ CREATE FUNCTION increment_version() RETURNS TRIGGER AS $$
+ BEGIN
+ NEW.version = OLD.version + 1;
+ RETURN NEW;
+ END;
+ $$ language 'plpgsql';
+
+ CREATE TRIGGER some_trigger
+ BEFORE UPDATE ON foo
+ FOR EACH ROW
+ WHEN (OLD.* IS DISTINCT FROM NEW.*)
+ EXECUTE PROCEDURE increment_version();
+
+ CREATE SCHEMA test;
+
+ CREATE TABLE test.foo(
+ test_schema_id INTEGER PRIMARY KEY,
+ test_schema_version INTEGER NOT NULL DEFAULT 0,
+ test_schema_content TEXT CHECK (LENGTH(test_schema_content) > 0)
+ );
+
+ CREATE FUNCTION test.dup(in int, out f1 int, out f2 text)
+ AS $$ SELECT $1, ' is int' $$
+ LANGUAGE SQL;
+
+ CREATE FUNCTION test.increment_version() RETURNS TRIGGER AS $$
+ BEGIN
+ NEW.version = OLD.version + 1;
+ RETURN NEW;
+ END;
+ $$ language 'plpgsql';
+
+ CREATE TRIGGER some_trigger
+ BEFORE UPDATE ON test.foo
+ FOR EACH ROW
+ WHEN (OLD.* IS DISTINCT FROM NEW.*)
+ EXECUTE PROCEDURE increment_version();
+
+ CREATE COLLATION test."some collation" (locale = 'en_US');
+
+ CREATE TABLE bar (
+ id INTEGER CHECK (id > 0),
+ author TEXT COLLATE test."some collation",
+ PRIMARY KEY (author, id)
+ ) PARTITION BY LIST (author);
+ CREATE INDEX some_partitioned_idx ON bar(author, id);
+ CREATE TABLE bar_1 PARTITION OF bar FOR VALUES IN ('some author 1');
+
+ CREATE TABLE test.bar (
+ test_id INTEGER CHECK (test_id > 0),
+ test_author TEXT,
+ PRIMARY KEY (test_author, test_id)
+ ) PARTITION BY LIST (test_author);
+ CREATE INDEX some_partitioned_idx ON test.bar(test_author, test_id);
+ CREATE TABLE test.bar_1 PARTITION OF test.bar FOR VALUES IN ('some author 1');
+
+ -- create a trigger on the original schema using a function from the other schema
+ CREATE TRIGGER some_trigger_using_other_schema_function
+ BEFORE UPDATE ON foo
+ FOR EACH ROW
+ WHEN (OLD.* IS DISTINCT FROM NEW.*)
+ EXECUTE PROCEDURE test.increment_version();
+ `},
+ expectedHash: "a6a845ad846dc362",
+ expectedSchema: schema.Schema{
+ Name: "public",
+ Tables: []schema.Table{
+ {
+ Name: "foo",
+ Columns: []schema.Column{
+ {Name: "id", Type: "integer", Size: 4},
+ {Name: "version", Type: "integer", Default: "0", Size: 4},
+ {Name: "content", Type: "text", IsNullable: true, Size: -1, Collation: defaultCollation},
+ },
+ CheckConstraints: []schema.CheckConstraint{
+ {Name: "foo_id_check", Expression: "(id > 0)", IsValid: true, IsInheritable: true},
+ },
+ },
+ {
+ Name: "bar",
+ Columns: []schema.Column{
+ {Name: "id", Type: "integer", Default: "", Size: 4},
+ {Name: "author", Type: "text", Default: "", Size: -1, Collation: schema.SchemaQualifiedName{SchemaName: "test", EscapedName: `"some collation"`}},
+ },
+ CheckConstraints: []schema.CheckConstraint{
+ {Name: "bar_id_check", Expression: "(id > 0)", IsValid: true, IsInheritable: true},
+ },
+ PartitionKeyDef: "LIST (author)",
+ },
+ {
+ ParentTableName: "bar",
+ Name: "bar_1",
+ Columns: []schema.Column{
+ {Name: "id", Type: "integer", Default: "", Size: 4},
+ {Name: "author", Type: "text", Default: "", Size: -1, Collation: schema.SchemaQualifiedName{SchemaName: "test", EscapedName: `"some collation"`}},
+ },
+ CheckConstraints: nil,
+ ForValues: "FOR VALUES IN ('some author 1')",
+ },
+ },
+ Indexes: []schema.Index{
+ // foo indexes
+ {
+ TableName: "foo",
+ Name: "foo_pkey", Columns: []string{"id"}, IsPk: true, IsUnique: true, ConstraintName: "foo_pkey",
+ GetIndexDefStmt: "CREATE UNIQUE INDEX foo_pkey ON public.foo USING btree (id)",
+ },
+ // bar indexes
+ {
+ TableName: "bar",
+ Name: "bar_pkey", Columns: []string{"author", "id"}, IsPk: true, IsUnique: true, ConstraintName: "bar_pkey", ParentIdxName: "",
+ GetIndexDefStmt: "CREATE UNIQUE INDEX bar_pkey ON ONLY public.bar USING btree (author, id)",
+ },
+ {
+ TableName: "bar",
+ Name: "some_partitioned_idx", Columns: []string{"author", "id"}, IsPk: false, IsUnique: false, ConstraintName: "", ParentIdxName: "",
+ GetIndexDefStmt: "CREATE INDEX some_partitioned_idx ON ONLY public.bar USING btree (author, id)",
+ },
+ // bar_1 indexes
+ {
+ TableName: "bar_1",
+ Name: "bar_1_author_id_idx", Columns: []string{"author", "id"}, IsPk: false, IsUnique: false, ConstraintName: "", ParentIdxName: "some_partitioned_idx",
+ GetIndexDefStmt: "CREATE INDEX bar_1_author_id_idx ON public.bar_1 USING btree (author, id)",
+ },
+ {
+ TableName: "bar_1",
+ Name: "bar_1_pkey", Columns: []string{"author", "id"}, IsPk: true, IsUnique: true, ConstraintName: "bar_1_pkey", ParentIdxName: "bar_pkey",
+ GetIndexDefStmt: "CREATE UNIQUE INDEX bar_1_pkey ON public.bar_1 USING btree (author, id)",
+ },
+ },
+ Functions: []schema.Function{
+ {
+ SchemaQualifiedName: schema.SchemaQualifiedName{EscapedName: "\"dup\"(integer, OUT f1 integer, OUT f2 text)", SchemaName: "public"},
+ FunctionDef: "CREATE OR REPLACE FUNCTION public.dup(integer, OUT f1 integer, OUT f2 text)\n RETURNS record\n LANGUAGE sql\nAS $function$ SELECT $1, CAST($1 AS text) || ' is text' $function$\n",
+ Language: "sql",
+ },
+ {
+ SchemaQualifiedName: schema.SchemaQualifiedName{EscapedName: "\"increment_version\"()", SchemaName: "public"},
+ FunctionDef: "CREATE OR REPLACE FUNCTION public.increment_version()\n RETURNS trigger\n LANGUAGE plpgsql\nAS $function$\n\t\t\t\tBEGIN\n\t\t\t\t\tNEW.version = OLD.version + 1;\n\t\t\t\t\tRETURN NEW;\n\t\t\t\tEND;\n\t\t\t$function$\n",
+ Language: "plpgsql",
+ },
+ },
+ Triggers: []schema.Trigger{
+ {
+ EscapedName: "\"some_trigger\"",
+ OwningTable: schema.SchemaQualifiedName{EscapedName: "\"foo\"", SchemaName: "public"},
+ OwningTableUnescapedName: "foo",
+ Function: schema.SchemaQualifiedName{EscapedName: "\"increment_version\"()", SchemaName: "public"},
+ GetTriggerDefStmt: "CREATE TRIGGER some_trigger BEFORE UPDATE ON public.foo FOR EACH ROW WHEN ((old.* IS DISTINCT FROM new.*)) EXECUTE FUNCTION increment_version()",
+ },
+ },
+ },
+ },
+ {
+ name: "Empty Schema",
+ ddl: nil,
+ expectedHash: "660be155e4c39f8b",
+ expectedSchema: schema.Schema{
+ Name: "public",
+ Tables: nil,
+ },
+ },
+ {
+ name: "No Indexes or constraints",
+ ddl: []string{`
+ CREATE TABLE foo (
+ value TEXT
+ );
+ `},
+ expectedHash: "9db57cf969f0a509",
+ expectedSchema: schema.Schema{
+ Name: "public",
+ Tables: []schema.Table{
+ {
+ Name: "foo",
+ Columns: []schema.Column{
+ {Name: "value", Type: "text", IsNullable: true, Size: -1, Collation: defaultCollation},
+ },
+ CheckConstraints: nil,
+ },
+ },
+ },
+ },
+ }
+)
+
+func TestSchemaTestCases(t *testing.T) {
+ engine, err := pgengine.StartEngine()
+ require.NoError(t, err)
+ defer engine.Close()
+
+ for _, testCase := range testCases {
+ t.Run(testCase.name, func(t *testing.T) {
+ db, err := engine.CreateDatabase()
+ require.NoError(t, err)
+ conn, err := sql.Open("pgx", db.GetDSN())
+ require.NoError(t, err)
+
+ for _, stmt := range testCase.ddl {
+ _, err := conn.Exec(stmt)
+ require.NoError(t, err)
+ }
+
+ fetchedSchema, err := schema.GetPublicSchema(context.TODO(), conn)
+ if testCase.expectedErrIs != nil {
+ require.ErrorIs(t, err, testCase.expectedErrIs)
+ return
+ } else {
+ require.NoError(t, err)
+ }
+
+ expectedNormalized := testCase.expectedSchema.Normalize()
+ fetchedNormalized := fetchedSchema.Normalize()
+ assert.Equal(t, expectedNormalized, fetchedNormalized, "expected=\n%# v \n fetched=%# v\n", pretty.Formatter(expectedNormalized), pretty.Formatter(fetchedNormalized))
+
+ fetchedSchemaHash, err := fetchedSchema.Hash()
+ require.NoError(t, err)
+ expectedSchemaHash, err := testCase.expectedSchema.Hash()
+ require.NoError(t, err)
+ assert.Equal(t, testCase.expectedHash, fetchedSchemaHash)
+ // same schemas should have the same hashes
+ assert.Equal(t, expectedSchemaHash, fetchedSchemaHash, "hash of expected schema should match fetched hash")
+
+ require.NoError(t, conn.Close())
+ require.NoError(t, db.DropDB())
+ })
+ }
+}
+
+func TestIdxDefStmtToCreateIdxConcurrently(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ defStmt string
+ out string
+ expectErr bool
+ }{
+ {
+ name: "simple index",
+ defStmt: `CREATE INDEX foobar ON public.foobar USING btree (foo)`,
+ out: `CREATE INDEX CONCURRENTLY foobar ON public.foobar USING btree (foo)`,
+ },
+ {
+ name: "unique index",
+ defStmt: `CREATE UNIQUE INDEX foobar ON public.foobar USING btree (foo)`,
+ out: `CREATE UNIQUE INDEX CONCURRENTLY foobar ON public.foobar USING btree (foo)`,
+ },
+ {
+ name: "malicious name index",
+ defStmt: `CREATE UNIQUE INDEX "CREATE INDEX ON" ON public.foobar USING btree (foo)`,
+ out: `CREATE UNIQUE INDEX CONCURRENTLY "CREATE INDEX ON" ON public.foobar USING btree (foo)`,
+ },
+ {
+ name: "case sensitive",
+ defStmt: `CREATE uNIQUE INDEX foobar ON public.foobar USING btree (foo)`,
+ expectErr: true,
+ },
+ {
+ name: "errors with random start character",
+ defStmt: `ALTER TABLE CREATE UNIQUE INDEX foobar ON public.foobar USING btree (foo)`,
+ expectErr: true,
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ out, err := schema.GetIndexDefStatement(tc.defStmt).ToCreateIndexConcurrently()
+ if tc.expectErr {
+ require.Error(t, err)
+ } else {
+ require.NoError(t, err)
+ require.Equal(t, tc.out, out)
+ }
+ })
+ }
+}
+
+func TestTriggerDefStmtToCreateOrReplace(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ defStmt string
+ out string
+ expectErr bool
+ }{
+ {
+ name: "simple trigger",
+ defStmt: "CREATE TRIGGER some_trigger BEFORE UPDATE ON public.foo FOR EACH ROW WHEN ((old.* IS DISTINCT FROM new.*)) EXECUTE FUNCTION increment_version()",
+ out: "CREATE OR REPLACE TRIGGER some_trigger BEFORE UPDATE ON public.foo FOR EACH ROW WHEN ((old.* IS DISTINCT FROM new.*)) EXECUTE FUNCTION increment_version()",
+ },
+ {
+ name: "malicious name trigger",
+ defStmt: `CREATE TRIGGER "CREATE TRIGGER" BEFORE UPDATE ON public.foo FOR EACH ROW WHEN ((old.* IS DISTINCT FROM new.*)) EXECUTE FUNCTION increment_version()`,
+ out: `CREATE OR REPLACE TRIGGER "CREATE TRIGGER" BEFORE UPDATE ON public.foo FOR EACH ROW WHEN ((old.* IS DISTINCT FROM new.*)) EXECUTE FUNCTION increment_version()`,
+ },
+ {
+ name: "case sensitive",
+ defStmt: "cREATE TRIGGER some_trigger BEFORE UPDATE ON public.foo FOR EACH ROW WHEN ((old.* IS DISTINCT FROM new.*)) EXECUTE FUNCTION increment_version()",
+ expectErr: true,
+ },
+ {
+ name: "errors with random start character",
+ defStmt: "ALTER TRIGGER CREATE TRIGGER some_trigger BEFORE UPDATE ON public.foo FOR EACH ROW WHEN ((old.* IS DISTINCT FROM new.*)) EXECUTE FUNCTION increment_version()",
+ expectErr: true,
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ out, err := schema.GetTriggerDefStatement(tc.defStmt).ToCreateOrReplace()
+ if tc.expectErr {
+ require.Error(t, err)
+ } else {
+ require.NoError(t, err)
+ require.Equal(t, tc.out, out)
+ }
+ })
+ }
+}
diff --git a/pkg/diff/diff.go b/pkg/diff/diff.go
new file mode 100644
index 0000000..be9cfb5
--- /dev/null
+++ b/pkg/diff/diff.go
@@ -0,0 +1,323 @@
+package diff
+
+import (
+ "fmt"
+ "sort"
+
+ "github.com/stripe/pg-schema-diff/internal/graph"
+ "github.com/stripe/pg-schema-diff/internal/schema"
+)
+
+var ErrNotImplemented = fmt.Errorf("not implemented")
+var errDuplicateIdentifier = fmt.Errorf("duplicate identifier")
+
+type diffType string
+
+const (
+ diffTypeDelete diffType = "DELETE"
+ diffTypeAddAlter diffType = "ADDALTER"
+)
+
+type (
+ diff[S schema.Object] interface {
+ GetOld() S
+ GetNew() S
+ }
+
+ // sqlGenerator is used to generate SQL that resolves diffs between lists
+ sqlGenerator[S schema.Object, Diff diff[S]] interface {
+ Add(S) ([]Statement, error)
+ Delete(S) ([]Statement, error)
+ // Alter generates the statements required to resolve the schema object to its new state using the
+ // provided diff. For example, altering a table might produce add/delete statements
+ Alter(Diff) ([]Statement, error)
+ }
+
+ // dependency indicates an edge between the SQL to resolve a diff for a source schema object and the SQL to resolve
+ // the diff of a target schema object
+ //
+ // Most SchemaObjects will have two nodes in the SQL graph: a node for delete SQL and a node for add/alter SQL.
+ // These nodes will almost always be present in the sqlGraph even if the schema object is not being deleted (or added/altered).
+ // If a node is present for a schema object where the "diffType" is NOT occurring, it will just be a no-op (no SQL statements)
+ dependency struct {
+ sourceObjId string
+ sourceType diffType
+
+ targetObjId string
+ targetType diffType
+ }
+)
+
+type dependencyBuilder struct {
+ valObjId string
+ valType diffType
+}
+
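+// mustRun starts a fluent dependency declaration. For example (the object ids here are hypothetical;
+// each sqlVertexGenerator defines its own vertex ids):
+//
+//	mustRun("index_some_idx", diffTypeDelete).before("table_foo", diffTypeAddAlter)
+//
+// declares that the SQL deleting the index must run before the SQL adding/altering the table.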
+func mustRun(schemaObjId string, schemaDiffType diffType) dependencyBuilder {
+ return dependencyBuilder{
+ valObjId: schemaObjId,
+ valType: schemaDiffType,
+ }
+}
+
+func (d dependencyBuilder) before(valObjId string, valType diffType) dependency {
+ return dependency{
+ sourceType: d.valType,
+ sourceObjId: d.valObjId,
+
+ targetType: valType,
+ targetObjId: valObjId,
+ }
+}
+
+func (d dependencyBuilder) after(valObjId string, valType diffType) dependency {
+ return dependency{
+ sourceObjId: valObjId,
+ sourceType: valType,
+
+ targetObjId: d.valObjId,
+ targetType: d.valType,
+ }
+}
+
+// sqlVertexGenerator is used to generate SQL statements for schema objects that have dependency webs
+// with other schema objects. The schema object represents a vertex in the graph.
+type sqlVertexGenerator[S schema.Object, Diff diff[S]] interface {
+ sqlGenerator[S, Diff]
+ // GetSQLVertexId gets the canonical vertex id to represent the schema object
+ GetSQLVertexId(S) string
+
+ // GetAddAlterDependencies gets the dependencies of the SQL generated to resolve the AddAlter diff for the
+ // schema objects. Dependencies can be formed on any other nodes in the SQL graph, even if the node has
+ // no statements. If the diff is just an add, then old will be the zero value
+ //
+ // These dependencies can also be built in reverse: the SQL returned by the sqlVertexGenerator to resolve the
+ // diff for the object must always be run before the SQL required to resolve another SQL vertex diff
+ GetAddAlterDependencies(new S, old S) []dependency
+
+ // GetDeleteDependencies is the same as above but for deletes.
+ // Invariant to maintain:
+ // - If an object X depends on the delete of an object Y (generated by the sqlVertexGenerator), then immediately
+ // after the (Y, diffTypeDelete) sqlVertex's SQL is run, Y must no longer be present in the schema; either the
+ // (Y, diffTypeDelete) statements deleted Y or something that vertex depended on deleted Y. In other words, if a
+ // delete is cascaded by another delete (e.g., an index dropped by a table drop) and the index's delete SQL is empty,
+ // the index delete vertex must still have a dependency on the object from which the delete cascades
+ GetDeleteDependencies(S) []dependency
+}
+
+type (
+ // listDiff represents the differences between two lists.
+ listDiff[S schema.Object, Diff diff[S]] struct {
+ adds []S
+ deletes []S
+ // alters contains the diffs of any objects that persisted between two schemas
+ alters []Diff
+ }
+
+ sqlGroupedByEffect[S schema.Object, Diff diff[S]] struct {
+ Adds []Statement
+ Deletes []Statement
+ // Alters might contain adds and deletes. For example, a set of alters for a table might add indexes.
+ Alters []Statement
+ }
+)
+
+func (ld listDiff[S, D]) isEmpty() bool {
+ return len(ld.adds) == 0 && len(ld.alters) == 0 && len(ld.deletes) == 0
+}
+
+func (ld listDiff[S, D]) resolveToSQLGroupedByEffect(sqlGenerator sqlGenerator[S, D]) (sqlGroupedByEffect[S, D], error) {
+ var adds, deletes, alters []Statement
+
+ for _, a := range ld.adds {
+ statements, err := sqlGenerator.Add(a)
+ if err != nil {
+ return sqlGroupedByEffect[S, D]{}, fmt.Errorf("generating SQL for add %s: %w", a.GetName(), err)
+ }
+ adds = append(adds, statements...)
+ }
+ for _, d := range ld.deletes {
+ statements, err := sqlGenerator.Delete(d)
+ if err != nil {
+ return sqlGroupedByEffect[S, D]{}, fmt.Errorf("generating SQL for delete %s: %w", d.GetName(), err)
+ }
+ deletes = append(deletes, statements...)
+ }
+ for _, a := range ld.alters {
+ statements, err := sqlGenerator.Alter(a)
+ if err != nil {
+ return sqlGroupedByEffect[S, D]{}, fmt.Errorf("generating SQL for diff %+v: %w", a, err)
+ }
+ alters = append(alters, statements...)
+ }
+
+ return sqlGroupedByEffect[S, D]{
+ Adds: adds,
+ Deletes: deletes,
+ Alters: alters,
+ }, nil
+}
+
+func (ld listDiff[S, D]) resolveToSQLGraph(generator sqlVertexGenerator[S, D]) (*sqlGraph, error) {
+ graph := graph.NewGraph[sqlVertex]()
+
+ for _, a := range ld.adds {
+ statements, err := generator.Add(a)
+ if err != nil {
+ return nil, fmt.Errorf("generating SQL for add %s: %w", a.GetName(), err)
+ }
+
+ if err := addSQLVertexToGraph(graph, sqlVertex{
+ ObjId: generator.GetSQLVertexId(a),
+ Statements: statements,
+ DiffType: diffTypeAddAlter,
+ }, generator.GetAddAlterDependencies(a, *new(S))); err != nil {
+ return nil, fmt.Errorf("adding SQL Vertex for add %s: %w", a.GetName(), err)
+ }
+ }
+
+ for _, a := range ld.alters {
+ statements, err := generator.Alter(a)
+ if err != nil {
+ return nil, fmt.Errorf("generating SQL for diff %+v: %w", a, err)
+ }
+
+ vertexId := generator.GetSQLVertexId(a.GetOld())
+ vertexIdAfterAlter := generator.GetSQLVertexId(a.GetNew())
+ if vertexIdAfterAlter != vertexId {
+ return nil, fmt.Errorf("an alter lead to a node with a different id: old=%s, new=%s", vertexId, vertexIdAfterAlter)
+ }
+
+ if err := addSQLVertexToGraph(graph, sqlVertex{
+ ObjId: vertexId,
+ Statements: statements,
+ DiffType: diffTypeAddAlter,
+ }, generator.GetAddAlterDependencies(a.GetNew(), a.GetNew())); err != nil {
+ return nil, fmt.Errorf("adding SQL Vertex for alter %s: %w", a.GetOld().GetName(), err)
+ }
+ }
+
+ for _, d := range ld.deletes {
+ statements, err := generator.Delete(d)
+ if err != nil {
+ return nil, fmt.Errorf("generating SQL for delete %s: %w", d.GetName(), err)
+ }
+
+ if err := addSQLVertexToGraph(graph, sqlVertex{
+ ObjId: generator.GetSQLVertexId(d),
+ Statements: statements,
+ DiffType: diffTypeDelete,
+ }, generator.GetDeleteDependencies(d)); err != nil {
+ return nil, fmt.Errorf("adding SQL Vertex for delete %s: %w", d.GetName(), err)
+ }
+ }
+
+ return (*sqlGraph)(graph), nil
+}
+
+func addSQLVertexToGraph(graph *graph.Graph[sqlVertex], vertex sqlVertex, dependencies []dependency) error {
+ // It's possible the node already exists; merge it if it does
+ if graph.HasVertexWithId(vertex.GetId()) {
+ vertex = mergeSQLVertices(graph.GetVertex(vertex.GetId()), vertex)
+ }
+ graph.AddVertex(vertex)
+ for _, dep := range dependencies {
+ if err := addDependency(graph, dep); err != nil {
+ return fmt.Errorf("adding dependencies for %s: %w", vertex.GetId(), err)
+ }
+ }
+ return nil
+}
+
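+// addDependency adds an edge from the dependency's source vertex to its target vertex, creating
+// placeholder (no-op) vertices for either endpoint if they are not yet in the graph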
+func addDependency(graph *graph.Graph[sqlVertex], dep dependency) error {
+ sourceVertex := sqlVertex{
+ ObjId: dep.sourceObjId,
+ DiffType: dep.sourceType,
+ Statements: nil,
+ }
+ targetVertex := sqlVertex{
+ ObjId: dep.targetObjId,
+ DiffType: dep.targetType,
+ Statements: nil,
+ }
+
+ // To maintain the correctness of the graph, we will add a dummy vertex for the missing dependencies
+ addVertexIfNotExists(graph, sourceVertex)
+ addVertexIfNotExists(graph, targetVertex)
+
+ if err := graph.AddEdge(sourceVertex.GetId(), targetVertex.GetId()); err != nil {
+ return fmt.Errorf("adding edge from %s to %s: %w", sourceVertex.GetId(), targetVertex.GetId(), err)
+ }
+
+ return nil
+}
+
+func addVertexIfNotExists(graph *graph.Graph[sqlVertex], vertex sqlVertex) {
+ if !graph.HasVertexWithId(vertex.GetId()) {
+ graph.AddVertex(vertex)
+ }
+}
+
+type schemaObjectEntry[S schema.Object] struct {
+ index int // index is the index the schema object in the list
+ obj S
+}
+
+// diffLists diffs two lists of schema objects by name.
+// If an object is present in both lists, the buildDiff function is used to build the diff between the two objects. If
+// buildDiff reports requiresRecreation, then the old schema object will be deleted and the new one will be added
+//
+// The output is in a deterministic order by schema object name, which is important for tests
+func diffLists[S schema.Object, Diff diff[S]](
+ oldSchemaObjs, newSchemaObjs []S,
+ buildDiff func(old, new S, oldIndex, newIndex int) (diff Diff, requiresRecreation bool, error error),
+) (listDiff[S, Diff], error) {
+ nameToOld := make(map[string]schemaObjectEntry[S])
+ for oldIndex, oldSchemaObject := range oldSchemaObjs {
+ if _, nameAlreadyTaken := nameToOld[oldSchemaObject.GetName()]; nameAlreadyTaken {
+ return listDiff[S, Diff]{}, fmt.Errorf("multiple objects have identifier %s: %w", oldSchemaObject.GetName(), errDuplicateIdentifier)
+ }
+ // store the old schema object and its index; if this turns into an alteration, the index might be used in the diff, e.g., for columns
+ nameToOld[oldSchemaObject.GetName()] = schemaObjectEntry[S]{
+ obj: oldSchemaObject,
+ index: oldIndex,
+ }
+ }
+
+ var adds []S
+ var alters []Diff
+ var deletes []S
+ for newIndex, newSchemaObj := range newSchemaObjs {
+ if oldSchemaObjAndIndex, hasOldSchemaObj := nameToOld[newSchemaObj.GetName()]; !hasOldSchemaObj {
+ adds = append(adds, newSchemaObj)
+ } else {
+ delete(nameToOld, newSchemaObj.GetName())
+
+ diff, requiresRecreation, err := buildDiff(oldSchemaObjAndIndex.obj, newSchemaObj, oldSchemaObjAndIndex.index, newIndex)
+ if err != nil {
+ return listDiff[S, Diff]{}, fmt.Errorf("diffing for %s: %w", newSchemaObj.GetName(), err)
+ }
+ if requiresRecreation {
+ deletes = append(deletes, oldSchemaObjAndIndex.obj)
+ adds = append(adds, newSchemaObj)
+ } else {
+ alters = append(alters, diff)
+ }
+ }
+ }
+
+ // Remaining schema objects in nameToOld have been deleted
+ for _, d := range nameToOld {
+ deletes = append(deletes, d.obj)
+ }
+ // Iterating through a map is non-deterministic in go, so we'll sort the deletes by schema object name
+ sort.Slice(deletes, func(i, j int) bool {
+ return deletes[i].GetName() < deletes[j].GetName()
+ })
+
+ return listDiff[S, Diff]{
+ adds: adds,
+ deletes: deletes,
+ alters: alters,
+ }, nil
+}
diff --git a/pkg/diff/plan.go b/pkg/diff/plan.go
new file mode 100644
index 0000000..589903d
--- /dev/null
+++ b/pkg/diff/plan.go
@@ -0,0 +1,69 @@
+package diff
+
+import (
+ "fmt"
+ "regexp"
+ "time"
+)
+
+type MigrationHazardType = string
+
+const (
+ MigrationHazardTypeAcquiresAccessExclusiveLock MigrationHazardType = "ACQUIRES_ACCESS_EXCLUSIVE_LOCK"
+ MigrationHazardTypeAcquiresShareLock MigrationHazardType = "ACQUIRES_SHARE_LOCK"
+ MigrationHazardTypeDeletesData MigrationHazardType = "DELETES_DATA"
+ MigrationHazardTypeHasUntrackableDependencies MigrationHazardType = "HAS_UNTRACKABLE_DEPENDENCIES"
+ MigrationHazardTypeIndexBuild MigrationHazardType = "INDEX_BUILD"
+ MigrationHazardTypeIndexDropped MigrationHazardType = "INDEX_DROPPED"
+ MigrationHazardTypeImpactsDatabasePerformance MigrationHazardType = "IMPACTS_DATABASE_PERFORMANCE"
+ MigrationHazardTypeIsUserGenerated MigrationHazardType = "IS_USER_GENERATED"
+)
+
+type MigrationHazard struct {
+ Type MigrationHazardType
+ Message string
+}
+
+func (p MigrationHazard) String() string {
+ return fmt.Sprintf("%s: %s", p.Type, p.Message)
+}
+
+type Statement struct {
+ DDL string
+ Timeout time.Duration
+ Hazards []MigrationHazard
+}
+
+func (s Statement) ToSQL() string {
+ return s.DDL + ";"
+}
+
+type Plan struct {
+ Statements []Statement
+ CurrentSchemaHash string
+}
+
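+// ApplyStatementTimeoutModifier sets the given timeout on every statement whose DDL matches the
+// regex. For example (illustrative only), to give concurrent index builds a longer timeout:
+//
+//	plan = plan.ApplyStatementTimeoutModifier(
+//		regexp.MustCompile(`(?i)CREATE\s+(UNIQUE\s+)?INDEX\s+CONCURRENTLY`),
+//		20*time.Minute,
+//	)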
+func (p Plan) ApplyStatementTimeoutModifier(regex *regexp.Regexp, timeout time.Duration) Plan {
+ var modifiedStmts []Statement
+ for _, stmt := range p.Statements {
+ if regex.MatchString(stmt.DDL) {
+ stmt.Timeout = timeout
+ }
+ modifiedStmts = append(modifiedStmts, stmt)
+ }
+ p.Statements = modifiedStmts
+ return p
+}
+
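+// InsertStatement inserts the given statement into the plan at the given index and returns the
+// updated plan. For example (illustrative only), to prepend a statement to the plan:
+//
+//	plan, err := plan.InsertStatement(0, diff.Statement{DDL: "SET lock_timeout = '1s'", Timeout: time.Second})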
+func (p Plan) InsertStatement(index int, statement Statement) (Plan, error) {
+ if index < 0 || index > len(p.Statements) {
+ return Plan{}, fmt.Errorf("index must be >= 0 and <= %d", len(p.Statements))
+ }
+ if index == len(p.Statements) {
+ p.Statements = append(p.Statements, statement)
+ return p, nil
+ }
+ p.Statements = append(p.Statements[:index+1], p.Statements[index:]...)
+ p.Statements[index] = statement
+ return p, nil
+}
diff --git a/pkg/diff/plan_generator.go b/pkg/diff/plan_generator.go
new file mode 100644
index 0000000..7025d95
--- /dev/null
+++ b/pkg/diff/plan_generator.go
@@ -0,0 +1,244 @@
+package diff
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "strings"
+
+ _ "github.com/jackc/pgx/v4/stdlib"
+ "github.com/kr/pretty"
+ "github.com/stripe/pg-schema-diff/internal/schema"
+
+ "github.com/stripe/pg-schema-diff/internal/queries"
+ "github.com/stripe/pg-schema-diff/pkg/log"
+ "github.com/stripe/pg-schema-diff/pkg/tempdb"
+)
+
+type (
+ planOptions struct {
+ dataPackNewTables bool
+ ignoreChangesToColOrder bool
+ logger log.Logger
+ validatePlan bool
+ }
+
+ PlanOpt func(opts *planOptions)
+)
+
+// WithDataPackNewTables configures the plan generation such that it packs the columns in new tables to minimize
+// padding, which helps minimize the storage used by those tables
+func WithDataPackNewTables() PlanOpt {
+ return func(opts *planOptions) {
+ opts.dataPackNewTables = true
+ }
+}
+
+// WithIgnoreChangesToColOrder configures the plan generation to ignore any changes to the ordering of columns in
+// existing tables. You will most likely want this enabled
+func WithIgnoreChangesToColOrder() PlanOpt {
+ return func(opts *planOptions) {
+ opts.ignoreChangesToColOrder = true
+ }
+}
+
+func WithDoNotValidatePlan() PlanOpt {
+ return func(opts *planOptions) {
+ opts.validatePlan = false
+ }
+}
+
+func WithLogger(logger log.Logger) PlanOpt {
+ return func(opts *planOptions) {
+ opts.logger = logger
+ }
+}
+
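+// GeneratePlan generates a migration plan to migrate the public schema of the database that conn is
+// connected to towards the schema produced by applying newDDL to a temporary database. A minimal
+// usage sketch (illustrative; ctx, conn, and tempDbFactory are assumed to already be set up, and the
+// DDL and option choices are examples only):
+//
+//	plan, err := diff.GeneratePlan(ctx, conn, tempDbFactory,
+//		[]string{"CREATE TABLE foobar(id TEXT PRIMARY KEY);"},
+//		diff.WithDataPackNewTables(),
+//		diff.WithIgnoreChangesToColOrder(),
+//	)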
+func GeneratePlan(ctx context.Context, conn *sql.Conn, tempDbFactory tempdb.Factory, newDDL []string, opts ...PlanOpt) (Plan, error) {
+ planOptions := &planOptions{
+ validatePlan: true,
+ logger: log.SimpleLogger(),
+ }
+ for _, opt := range opts {
+ opt(planOptions)
+ }
+
+ currentSchema, err := schema.GetPublicSchema(ctx, conn)
+ if err != nil {
+ return Plan{}, fmt.Errorf("getting current schema: %w", err)
+ }
+ newSchema, err := deriveSchemaFromDDLOnTempDb(ctx, planOptions.logger, tempDbFactory, newDDL)
+ if err != nil {
+ return Plan{}, fmt.Errorf("getting new schema: %w", err)
+ }
+
+ statements, err := generateMigrationStatements(currentSchema, newSchema, planOptions)
+ if err != nil {
+ return Plan{}, fmt.Errorf("generating plan statements: %w", err)
+ }
+
+ hash, err := currentSchema.Hash()
+ if err != nil {
+ return Plan{}, fmt.Errorf("generating current schema hash: %w", err)
+ }
+
+ plan := Plan{
+ Statements: statements,
+ CurrentSchemaHash: hash,
+ }
+
+ if planOptions.validatePlan {
+ if err := assertValidPlan(ctx, tempDbFactory, currentSchema, newSchema, plan, planOptions); err != nil {
+ return Plan{}, fmt.Errorf("validating migration plan: %w \n%# v", err, pretty.Formatter(plan))
+ }
+ }
+
+ return plan, nil
+}
+
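+// deriveSchemaFromDDLOnTempDb creates a throwaway database via the tempdb factory, applies the
+// provided DDL to it, and returns the resulting public schema. The temp database is dropped when done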
+func deriveSchemaFromDDLOnTempDb(ctx context.Context, logger log.Logger, tempDbFactory tempdb.Factory, ddl []string) (schema.Schema, error) {
+ tempDb, dropTempDb, err := tempDbFactory.Create(ctx)
+ if err != nil {
+ return schema.Schema{}, fmt.Errorf("creating temp database: %w", err)
+ }
+ defer func(drop tempdb.Dropper) {
+ if err := drop(ctx); err != nil {
+ logger.Errorf("an error occurred while dropping the temp database: %s", err)
+ }
+ }(dropTempDb)
+
+ for _, stmt := range ddl {
+ if _, err := tempDb.ExecContext(ctx, stmt); err != nil {
+ return schema.Schema{}, fmt.Errorf("running DDL: %w", err)
+ }
+ }
+
+ return schema.GetPublicSchema(ctx, tempDb)
+}
+
+func generateMigrationStatements(oldSchema, newSchema schema.Schema, planOptions *planOptions) ([]Statement, error) {
+ diff, _, err := buildSchemaDiff(oldSchema, newSchema)
+ if err != nil {
+ return nil, err
+ }
+
+ if planOptions.dataPackNewTables {
+ // Instead of enabling ignoreChangesToColOrder by default, force the user to enable ignoreChangesToColOrder.
+ // This ensures the user knows what's going on behind the scenes
+ if !planOptions.ignoreChangesToColOrder {
+ return nil, fmt.Errorf("cannot data pack new tables without also ignoring changes to column order")
+ }
+ diff = dataPackNewTables(diff)
+ }
+ if planOptions.ignoreChangesToColOrder {
+ diff = removeChangesToColumnOrdering(diff)
+ }
+
+ statements, err := diff.resolveToSQL()
+ if err != nil {
+ return nil, fmt.Errorf("generating migration statements: %w", err)
+ }
+ return statements, nil
+}
+
+func assertValidPlan(ctx context.Context,
+ tempDbFactory tempdb.Factory,
+ currentSchema, newSchema schema.Schema,
+ plan Plan,
+ planOptions *planOptions,
+) error {
+ tempDb, dropTempDb, err := tempDbFactory.Create(ctx)
+ if err != nil {
+ return err
+ }
+ defer func(drop tempdb.Dropper) {
+ if err := drop(ctx); err != nil {
+ planOptions.logger.Errorf("an error occurred while dropping the temp database: %s", err)
+ }
+ }(dropTempDb)
+
+ tempDbConn, err := tempDb.Conn(ctx)
+ if err != nil {
+ return fmt.Errorf("opening database connection: %w", err)
+ }
+ defer tempDbConn.Close()
+
+ if err := setSchemaForEmptyDatabase(ctx, tempDbConn, currentSchema); err != nil {
+ return fmt.Errorf("inserting schema in temporary database: %w", err)
+ }
+
+ if err := executeStatements(ctx, tempDbConn, plan.Statements); err != nil {
+ return fmt.Errorf("running migration plan: %w", err)
+ }
+
+ migratedSchema, err := schema.GetPublicSchema(ctx, tempDbConn)
+ if err != nil {
+ return fmt.Errorf("fetching schema from migrated database: %w", err)
+ }
+
+ return assertMigratedSchemaMatchesTarget(migratedSchema, newSchema, planOptions)
+}
+
+func setSchemaForEmptyDatabase(ctx context.Context, conn *sql.Conn, dbSchema schema.Schema) error {
+ // We can't create invalid indexes. We'll mark them valid in the schema, which should be functionally
+ // equivalent for the sake of DDL and other statements.
+ //
+ // Make a new slice so we don't mutate the underlying array of the original schema. Ideally, we would have a
+ // clone function in the future
+ var validIndexes []schema.Index
+ for _, idx := range dbSchema.Indexes {
+ idx.IsInvalid = false
+ validIndexes = append(validIndexes, idx)
+ }
+ dbSchema.Indexes = validIndexes
+
+ if statements, err := generateMigrationStatements(schema.Schema{
+ Name: "public",
+ Tables: nil,
+ Indexes: nil,
+ Functions: nil,
+ Triggers: nil,
+ }, dbSchema, &planOptions{}); err != nil {
+ return fmt.Errorf("building schema diff: %w", err)
+ } else {
+ return executeStatements(ctx, conn, statements)
+ }
+}
+
+func assertMigratedSchemaMatchesTarget(migratedSchema, targetSchema schema.Schema, planOptions *planOptions) error {
+ toTargetSchemaStmts, err := generateMigrationStatements(migratedSchema, targetSchema, planOptions)
+ if err != nil {
+ return fmt.Errorf("building schema diff between migrated database and new schema: %w", err)
+ }
+
+ if len(toTargetSchemaStmts) > 0 {
+ var stmtsStrs []string
+ for _, stmt := range toTargetSchemaStmts {
+ stmtsStrs = append(stmtsStrs, stmt.DDL)
+ }
+ return fmt.Errorf("diff detected:\n%s", strings.Join(stmtsStrs, "\n"))
+ }
+
+ return nil
+}
+
+// executeStatements executes the statements using the sql connection. It will modify the session-level
+// statement timeout of the underlying connection.
+func executeStatements(ctx context.Context, conn queries.DBTX, statements []Statement) error {
+ // Due to the way *sql.DB works, when a statement_timeout is set for the session, it will NOT be reset
+ // by default when the connection is returned to the pool.
+ //
+ // We can't set the timeout at the TRANSACTION level (for each transaction) because `CREATE INDEX CONCURRENTLY`
+ // cannot be run inside a transaction block. Postgres will error if you try to set a TRANSACTION-level
+ // timeout for it. SESSION-level statement_timeouts are respected by `CREATE INDEX CONCURRENTLY`
+ for _, stmt := range statements {
+ if _, err := conn.ExecContext(ctx, fmt.Sprintf("SET SESSION statement_timeout = %d", stmt.Timeout.Milliseconds())); err != nil {
+ return fmt.Errorf("setting statement timeout: %w", err)
+ }
+ if _, err := conn.ExecContext(ctx, stmt.ToSQL()); err != nil {
+ // could the migration statement contain sensitive information?
+ return fmt.Errorf("executing migration statement: %s: %w", stmt, err)
+ }
+ }
+ return nil
+}
diff --git a/pkg/diff/plan_generator_test.go b/pkg/diff/plan_generator_test.go
new file mode 100644
index 0000000..c29b332
--- /dev/null
+++ b/pkg/diff/plan_generator_test.go
@@ -0,0 +1,125 @@
+package diff_test
+
+import (
+ "context"
+ "database/sql"
+ "io"
+ "testing"
+
+ _ "github.com/jackc/pgx/v4/stdlib"
+ "github.com/stretchr/testify/suite"
+ "github.com/stripe/pg-schema-diff/internal/pgengine"
+ "github.com/stripe/pg-schema-diff/pkg/diff"
+
+ "github.com/stripe/pg-schema-diff/pkg/tempdb"
+)
+
+type simpleMigratorTestSuite struct {
+ suite.Suite
+
+ pgEngine *pgengine.Engine
+ db *pgengine.DB
+}
+
+func (suite *simpleMigratorTestSuite) mustGetTestDBPool() *sql.DB {
+ pool, err := sql.Open("pgx", suite.db.GetDSN())
+ suite.NoError(err)
+ return pool
+}
+
+func (suite *simpleMigratorTestSuite) mustGetTestDBConn() (conn *sql.Conn, poolCloser io.Closer) {
+ pool := suite.mustGetTestDBPool()
+ conn, err := pool.Conn(context.Background())
+ suite.Require().NoError(err)
+ return conn, pool
+}
+
+func (suite *simpleMigratorTestSuite) mustBuildTempDbFactory(ctx context.Context) tempdb.Factory {
+ tempDbFactory, err := tempdb.NewOnInstanceFactory(ctx, func(ctx context.Context, dbName string) (*sql.DB, error) {
+ return sql.Open("pgx", suite.pgEngine.GetPostgresDatabaseConnOpts().With("dbname", dbName).ToDSN())
+ })
+ suite.Require().NoError(err)
+ return tempDbFactory
+}
+
+func (suite *simpleMigratorTestSuite) mustApplyDDLToTestDb(ddl []string) {
+ conn := suite.mustGetTestDBPool()
+ defer conn.Close()
+
+ for _, stmt := range ddl {
+ _, err := conn.Exec(stmt)
+ suite.NoError(err)
+ }
+}
+
+func (suite *simpleMigratorTestSuite) SetupSuite() {
+ engine, err := pgengine.StartEngine()
+ suite.Require().NoError(err)
+ suite.pgEngine = engine
+}
+
+func (suite *simpleMigratorTestSuite) TearDownSuite() {
+ suite.pgEngine.Close()
+}
+
+func (suite *simpleMigratorTestSuite) SetupTest() {
+ db, err := suite.pgEngine.CreateDatabase()
+ suite.NoError(err)
+ suite.db = db
+}
+
+func (suite *simpleMigratorTestSuite) TearDownTest() {
+ suite.db.DropDB()
+}
+
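+// TestPlanAndApplyMigration generates a plan against a live database and applies it end to end, asserting that
+// the migrated schema is usable (the new column can be queried).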
+func (suite *simpleMigratorTestSuite) TestPlanAndApplyMigration() {
+ initialDDL := `
+ CREATE TABLE foobar(
+ id CHAR(16) PRIMARY KEY
+ ); `
+ newSchemaDDL := `
+ CREATE TABLE foobar(
+ id CHAR(16) PRIMARY KEY,
+ new_column VARCHAR(128) NOT NULL
+ );
+ `
+
+ suite.mustApplyDDLToTestDb([]string{initialDDL})
+
+ conn, poolCloser := suite.mustGetTestDBConn()
+ defer poolCloser.Close()
+ defer conn.Close()
+
+ tempDbFactory := suite.mustBuildTempDbFactory(context.Background())
+ defer tempDbFactory.Close()
+
+ plan, err := diff.GeneratePlan(context.Background(), conn, tempDbFactory, []string{newSchemaDDL})
+ suite.NoError(err)
+
+ // Run the migration
+ for _, stmt := range plan.Statements {
+ _, err = conn.ExecContext(context.Background(), stmt.ToSQL())
+ suite.Require().NoError(err)
+ }
+ // Ensure that some sort of migration ran. We're not really testing the correctness of the
+ // migration in this test suite.
+ _, err = conn.ExecContext(context.Background(),
+ "SELECT new_column FROM foobar;")
+ suite.NoError(err)
+}
+
+func (suite *simpleMigratorTestSuite) TestCannotPackNewTablesWithoutIgnoringChangesToColumnOrder() {
+ tempDbFactory := suite.mustBuildTempDbFactory(context.Background())
+ defer tempDbFactory.Close()
+
+ conn, poolCloser := suite.mustGetTestDBConn()
+ defer poolCloser.Close()
+ defer conn.Close()
+
+ _, err := diff.GeneratePlan(context.Background(), conn, tempDbFactory, []string{``}, diff.WithDataPackNewTables())
+ suite.ErrorContains(err, "cannot data pack new tables without also ignoring changes to column order")
+}
+
+func TestSimpleMigratorTestSuite(t *testing.T) {
+ suite.Run(t, new(simpleMigratorTestSuite))
+}
diff --git a/pkg/diff/plan_test.go b/pkg/diff/plan_test.go
new file mode 100644
index 0000000..f15b6a4
--- /dev/null
+++ b/pkg/diff/plan_test.go
@@ -0,0 +1,240 @@
+package diff_test
+
+import (
+ "regexp"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "github.com/stripe/pg-schema-diff/pkg/diff"
+)
+
+func TestPlan_ApplyStatementTimeoutModifier(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ regex string
+ timeout time.Duration
+ plan diff.Plan
+ expectedPlan diff.Plan
+ }{
+ {
+ name: "no matches",
+ regex: "^.*x.*y$",
+ timeout: 2 * time.Hour,
+ plan: diff.Plan{
+ Statements: []diff.Statement{
+ {
+ DDL: "does-not-match-1",
+ Timeout: 3 * time.Second,
+ },
+ {
+ DDL: "does-not-match-2",
+ Timeout: time.Second,
+ },
+ {
+ DDL: "does-not-match-3",
+ Timeout: 2 * time.Second,
+ },
+ },
+ CurrentSchemaHash: "some-hash",
+ },
+ expectedPlan: diff.Plan{
+ Statements: []diff.Statement{
+ {
+ DDL: "does-not-match-1",
+ Timeout: 3 * time.Second,
+ },
+ {
+ DDL: "does-not-match-2",
+ Timeout: time.Second,
+ },
+ {
+ DDL: "does-not-match-3",
+ Timeout: 2 * time.Second,
+ },
+ },
+ CurrentSchemaHash: "some-hash",
+ },
+ },
+ {
+ name: "some match",
+ regex: "^.*x.*y$",
+ timeout: 2 * time.Hour,
+ plan: diff.Plan{
+ Statements: []diff.Statement{
+ {
+ DDL: "some-letters-than-an-x--end-in-y",
+ Timeout: 3 * time.Second,
+ },
+ {
+ DDL: "does-not-match",
+ Timeout: time.Second,
+ },
+ {
+ DDL: "other-letters xerox but ends in finally",
+ Timeout: 2 * time.Second,
+ },
+ },
+ CurrentSchemaHash: "some-hash",
+ },
+ expectedPlan: diff.Plan{
+ Statements: []diff.Statement{
+ {
+ DDL: "some-letters-than-an-x--end-in-y",
+ Timeout: 2 * time.Hour,
+ },
+ {
+ DDL: "does-not-match",
+ Timeout: time.Second,
+ },
+ {
+ DDL: "other-letters xerox but ends in finally",
+ Timeout: 2 * time.Hour,
+ },
+ },
+ CurrentSchemaHash: "some-hash",
+ },
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ regex := regexp.MustCompile(tc.regex)
+ resultingPlan := tc.plan.ApplyStatementTimeoutModifier(regex, tc.timeout)
+ assert.Equal(t, tc.expectedPlan, resultingPlan)
+ })
+ }
+}
+
+func TestPlan_InsertStatement(t *testing.T) {
+ var statementToInsert = diff.Statement{
+ DDL: "some DDL",
+ Timeout: 3 * time.Second,
+ Hazards: []diff.MigrationHazard{
+ {Type: diff.MigrationHazardTypeIsUserGenerated, Message: "user-generated"},
+ {Type: diff.MigrationHazardTypeAcquiresShareLock, Message: "acquires share lock"},
+ },
+ }
+
+ for _, tc := range []struct {
+ name string
+ plan diff.Plan
+ index int
+
+ expectedPlan diff.Plan
+ expectedErrContains string
+ }{
+ {
+ name: "empty plan",
+ plan: diff.Plan{
+ Statements: nil,
+ CurrentSchemaHash: "some-hash",
+ },
+ index: 0,
+
+ expectedPlan: diff.Plan{
+ Statements: []diff.Statement{
+ statementToInsert,
+ },
+ CurrentSchemaHash: "some-hash",
+ },
+ },
+ {
+ name: "insert at start",
+ plan: diff.Plan{
+ Statements: []diff.Statement{
+ {DDL: "statement 1", Timeout: time.Second},
+ {DDL: "statement 2", Timeout: 2 * time.Second},
+ {DDL: "statement 3", Timeout: 3 * time.Second},
+ },
+ CurrentSchemaHash: "some-hash",
+ },
+ index: 0,
+
+ expectedPlan: diff.Plan{
+ Statements: []diff.Statement{
+ statementToInsert,
+ {DDL: "statement 1", Timeout: time.Second},
+ {DDL: "statement 2", Timeout: 2 * time.Second},
+ {DDL: "statement 3", Timeout: 3 * time.Second},
+ },
+ CurrentSchemaHash: "some-hash",
+ },
+ },
+ {
+ name: "insert after first statement",
+ plan: diff.Plan{
+ Statements: []diff.Statement{
+ {DDL: "statement 1", Timeout: time.Second},
+ {DDL: "statement 2", Timeout: 2 * time.Second},
+ {DDL: "statement 3", Timeout: 3 * time.Second},
+ },
+ CurrentSchemaHash: "some-hash",
+ },
+ index: 1,
+
+ expectedPlan: diff.Plan{
+ Statements: []diff.Statement{
+
+ {DDL: "statement 1", Timeout: time.Second},
+ statementToInsert,
+ {DDL: "statement 2", Timeout: 2 * time.Second},
+ {DDL: "statement 3", Timeout: 3 * time.Second},
+ },
+ CurrentSchemaHash: "some-hash",
+ },
+ },
+ {
+ name: "insert after last statement",
+ plan: diff.Plan{
+ Statements: []diff.Statement{
+ {DDL: "statement 1", Timeout: time.Second},
+ {DDL: "statement 2", Timeout: 2 * time.Second},
+ {DDL: "statement 3", Timeout: 3 * time.Second},
+ },
+ CurrentSchemaHash: "some-hash",
+ },
+ index: 3,
+
+ expectedPlan: diff.Plan{
+ Statements: []diff.Statement{
+ {DDL: "statement 1", Timeout: time.Second},
+ {DDL: "statement 2", Timeout: 2 * time.Second},
+ {DDL: "statement 3", Timeout: 3 * time.Second},
+ statementToInsert,
+ },
+ CurrentSchemaHash: "some-hash",
+ },
+ },
+ {
+ name: "errors on negative index",
+ plan: diff.Plan{
+ Statements: nil,
+ CurrentSchemaHash: "some-hash",
+ },
+ index: -1,
+
+ expectedErrContains: "index must be",
+ },
+ {
+ name: "errors on index greater than len",
+ plan: diff.Plan{
+ Statements: nil,
+ CurrentSchemaHash: "some-hash",
+ },
+ index: 1,
+
+ expectedErrContains: "index must be",
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ resultingPlan, err := tc.plan.InsertStatement(tc.index, statementToInsert)
+ if len(tc.expectedErrContains) > 0 {
+ assert.ErrorContains(t, err, tc.expectedErrContains)
+ return
+ }
+ require.NoError(t, err)
+ assert.Equal(t, tc.expectedPlan, resultingPlan)
+ })
+ }
+}
diff --git a/pkg/diff/schema_migration_plan_test.go b/pkg/diff/schema_migration_plan_test.go
new file mode 100644
index 0000000..b9596b0
--- /dev/null
+++ b/pkg/diff/schema_migration_plan_test.go
@@ -0,0 +1,1163 @@
+package diff
+
+import (
+ "testing"
+
+ "github.com/google/uuid"
+ "github.com/kr/pretty"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "github.com/stripe/pg-schema-diff/internal/schema"
+)
+
+type schemaMigrationPlanTestCase struct {
+ name string
+ oldSchema schema.Schema
+ newSchema schema.Schema
+ expectedStatements []Statement
+ expectedDiffErrIs error
+}
+
+// schemaMigrationPlanTestCases -- these test cases assert the exact migration plan that is expected
+// to be generated when migrating from the oldSchema to the newSchema. They assert how the migration
+// should occur by asserting the DDL
+//
+// Most test cases should be added to //pg-schema-diff/internal/migration_acceptance_test_cases (acceptance
+// tests) instead of here.
+//
+// The acceptance tests actually fetch the old/new schemas; run the migration; and validate the migration
+// updates the old schema to be equivalent to the new schema. However, they do not assert any DDL; they have
+// no expectation on how the migration should be done.
+//
+// The tests added here should just cover niche cases where you want to assert HOW the migration should be done (e.g.,
+// adding an index concurrently)
+var (
+ defaultCollation = schema.SchemaQualifiedName{
+ EscapedName: `"default"`,
+ SchemaName: "pg_catalog",
+ }
+ cCollation = schema.SchemaQualifiedName{
+ EscapedName: `"C"`,
+ SchemaName: "pg_catalog",
+ }
+
+ schemaMigrationPlanTestCases = []schemaMigrationPlanTestCase{
+ {
+ name: "No-op",
+ oldSchema: schema.Schema{
+ Tables: []schema.Table{
+ {
+ Name: "foobar",
+ Columns: []schema.Column{
+ {Name: "id", Type: "integer"},
+ {Name: "foo", Type: "character varying(255)", Default: "''::character varying", Collation: defaultCollation},
+ {Name: "bar", Type: "timestamp without time zone", IsNullable: true, Default: "CURRENT_TIMESTAMP"},
+ {Name: "fizz", Type: "boolean", Default: "CURRENT_TIMESTAMP"},
+ },
+ CheckConstraints: []schema.CheckConstraint{
+ {Name: "id_check", Expression: "(id > 0)", IsInheritable: true},
+ {Name: "bar_check", Expression: "(id > LENGTH(foo))", IsValid: true},
+ },
+ },
+ {
+ Name: "bar",
+ Columns: []schema.Column{
+ {Name: "id", Type: "character varying(255)", Default: "", IsNullable: false, Collation: defaultCollation},
+ {Name: "foo", Type: "integer", Default: "", IsNullable: true},
+ {Name: "bar", Type: "double precision", Default: "8.8", IsNullable: false},
+ {Name: "fizz", Type: "timestamp with time zone", Default: "CURRENT_TIMESTAMP", IsNullable: true},
+ {Name: "buzz", Type: "real", Default: "", IsNullable: false},
+ },
+ CheckConstraints: []schema.CheckConstraint{
+ {Name: "id_check", Expression: "(id + 10 > 0 )", IsValid: true},
+ {Name: "bar_check", Expression: "(foo > buzz)", IsInheritable: true},
+ },
+ },
+ {
+ Name: "fizz",
+ Columns: nil,
+ CheckConstraints: nil,
+ },
+ },
+ Indexes: []schema.Index{
+ // bar indexes
+ {
+ TableName: "bar",
+ Name: "bar_pkey", Columns: []string{"bar"}, IsPk: true, IsUnique: true, ConstraintName: "bar_pkey_non_default_name",
+ GetIndexDefStmt: "CREATE UNIQUE INDEX bar_pkey ON public.bar USING btree (bar)",
+ },
+ {
+ TableName: "bar",
+ Name: "bar_normal_idx", Columns: []string{"fizz"},
+ GetIndexDefStmt: "CREATE INDEX bar_normal_idx ON public.bar USING btree (fizz)",
+ },
+ {
+ TableName: "bar",
+ Name: "bar_unique_idx", IsUnique: true, Columns: []string{"fizz", "buzz"},
+ GetIndexDefStmt: "CREATE UNIQUE INDEX bar_unique_idx ON public.bar USING btree (fizz, buzz)",
+ },
+ // foobar indexes
+ {
+ TableName: "foobar",
+ Name: "foobar_pkey", Columns: []string{"id"}, IsPk: true, IsUnique: true, ConstraintName: "foobar_pkey",
+ GetIndexDefStmt: "CREATE UNIQUE INDEX foo_pkey ON public.foobar USING btree (id)",
+ },
+ {
+ TableName: "foobar",
+ Name: "foobar_normal_idx", Columns: []string{"fizz"},
+ GetIndexDefStmt: "CREATE INDEX foobar_normal_idx ON public.foobar USING btree (fizz)",
+ },
+ {
+ TableName: "foobar",
+ Name: "foobar_unique_idx", IsUnique: true, Columns: []string{"foo", "bar"},
+ GetIndexDefStmt: "CREATE UNIQUE INDEX foobar_unique_idx ON public.foobar USING btree (foo, bar)",
+ },
+ },
+ },
+ newSchema: schema.Schema{
+ Tables: []schema.Table{
+ {
+ Name: "foobar",
+ Columns: []schema.Column{
+ {Name: "id", Type: "integer"},
+ {Name: "foo", Type: "character varying(255)", Default: "''::character varying", Collation: defaultCollation},
+ {Name: "bar", Type: "timestamp without time zone", IsNullable: true, Default: "CURRENT_TIMESTAMP"},
+ {Name: "fizz", Type: "boolean", Default: "CURRENT_TIMESTAMP"},
+ },
+ CheckConstraints: []schema.CheckConstraint{
+ {Name: "id_check", Expression: "(id > 0)", IsInheritable: true},
+ {Name: "bar_check", Expression: "(id > LENGTH(foo))", IsValid: true},
+ },
+ },
+ {
+ Name: "bar",
+ Columns: []schema.Column{
+ {Name: "id", Type: "character varying(255)", Default: "", IsNullable: false, Collation: defaultCollation},
+ {Name: "foo", Type: "integer", Default: "", IsNullable: true},
+ {Name: "bar", Type: "double precision", Default: "8.8", IsNullable: false},
+ {Name: "fizz", Type: "timestamp with time zone", Default: "CURRENT_TIMESTAMP", IsNullable: true},
+ {Name: "buzz", Type: "real", Default: "", IsNullable: false},
+ },
+ CheckConstraints: []schema.CheckConstraint{
+ {Name: "id_check", Expression: "(id + 10 > 0 )", IsValid: true},
+ {Name: "bar_check", Expression: "(foo > buzz)", IsInheritable: true},
+ },
+ },
+ {
+ Name: "fizz",
+ Columns: nil,
+ CheckConstraints: nil,
+ },
+ },
+ Indexes: []schema.Index{
+ // bar indexes
+ {
+ TableName: "bar",
+ Name: "bar_pkey", Columns: []string{"bar"}, IsPk: true, IsUnique: true, ConstraintName: "bar_pkey_non_default_name",
+ GetIndexDefStmt: "CREATE UNIQUE INDEX bar_pkey ON public.bar USING btree (bar)",
+ },
+ {
+ TableName: "bar",
+ Name: "bar_normal_idx", Columns: []string{"fizz"},
+ GetIndexDefStmt: "CREATE INDEX bar_normal_idx ON public.bar USING btree (fizz)",
+ },
+ {
+ TableName: "bar",
+ Name: "bar_unique_idx", IsUnique: true, Columns: []string{"fizz", "buzz"},
+ GetIndexDefStmt: "CREATE UNIQUE INDEX bar_unique_idx ON public.bar USING btree (fizz, buzz)",
+ },
+ // foobar indexes
+ {
+ TableName: "foobar",
+ Name: "foobar_pkey", Columns: []string{"id"}, IsPk: true, IsUnique: true, ConstraintName: "foobar_pkey",
+ GetIndexDefStmt: "CREATE UNIQUE INDEX foo_pkey ON public.foobar USING btree (id)",
+ },
+ {
+ TableName: "foobar",
+ Name: "foobar_normal_idx", Columns: []string{"fizz"},
+ GetIndexDefStmt: "CREATE INDEX foobar_normal_idx ON public.foobar USING btree (fizz)",
+ },
+ {
+ TableName: "foobar",
+ Name: "foobar_unique_idx", IsUnique: true, Columns: []string{"foo", "bar"},
+ GetIndexDefStmt: "CREATE UNIQUE INDEX foobar_unique_idx ON public.foobar USING btree (foo, bar)",
+ },
+ },
+ },
+ expectedStatements: nil,
+ },
+ {
+ name: "Index replacement",
+ oldSchema: schema.Schema{
+ Tables: []schema.Table{
+ {
+ Name: "foobar",
+ Columns: []schema.Column{
+ {Name: "id", Type: "integer"},
+ {Name: "foo", Type: "character varying(255)", Default: "''::character varying", Collation: defaultCollation},
+ {Name: "bar", Type: "timestamp without time zone", IsNullable: true, Default: "CURRENT_TIMESTAMP"},
+ },
+ CheckConstraints: nil,
+ },
+ },
+ Indexes: []schema.Index{
+ {
+ TableName: "foobar",
+ Name: "foo_idx", Columns: []string{"foo"},
+ GetIndexDefStmt: "CREATE INDEX foo_idx ON public.foobar USING btree (foo)",
+ },
+ {
+ TableName: "foobar",
+ Name: "replaced_with_same_name_idx", Columns: []string{"bar"},
+ GetIndexDefStmt: "CREATE INDEX replaced_with_same_name_idx ON ONLY public.foobar USING btree (bar)",
+ },
+ },
+ },
+ newSchema: schema.Schema{
+ Tables: []schema.Table{
+ {
+ Name: "foobar",
+ Columns: []schema.Column{
+ {Name: "id", Type: "integer"},
+ {Name: "foo", Type: "character varying(255)", Default: "''::character varying", Collation: defaultCollation},
+ {Name: "bar", Type: "timestamp without time zone", IsNullable: true, Default: "CURRENT_TIMESTAMP"},
+ },
+ CheckConstraints: nil,
+ },
+ },
+ Indexes: []schema.Index{
+ {
+ TableName: "foobar",
+ Name: "new_foo_idx", Columns: []string{"foo"},
+ GetIndexDefStmt: "CREATE INDEX new_foo_idx ON public.foobar USING btree (foo)",
+ },
+ {
+ TableName: "foobar",
+ Name: "replaced_with_same_name_idx", Columns: []string{"bar", "foo"},
+ GetIndexDefStmt: "CREATE INDEX replaced_with_same_name_idx ON ONLY public.foobar USING btree (bar)",
+ },
+ },
+ },
+ expectedStatements: []Statement{
+ {
+ DDL: "ALTER INDEX \"replaced_with_same_name_idx\" RENAME TO \"replaced_with_same_name_id_00010203-0405-4607-8809-0a0b0c0d0e0f\"",
+ Timeout: statementTimeoutDefault,
+ Hazards: nil,
+ },
+ {
+ DDL: "CREATE INDEX CONCURRENTLY new_foo_idx ON public.foobar USING btree (foo)",
+ Timeout: statementTimeoutConcurrentIndexBuild,
+ Hazards: []MigrationHazard{buildIndexBuildHazard()},
+ },
+ {
+ DDL: "CREATE INDEX CONCURRENTLY replaced_with_same_name_idx ON ONLY public.foobar USING btree (bar)",
+ Timeout: statementTimeoutConcurrentIndexBuild,
+ Hazards: []MigrationHazard{buildIndexBuildHazard()},
+ },
+ {
+ DDL: "DROP INDEX CONCURRENTLY \"foo_idx\"",
+ Timeout: statementTimeoutConcurrentIndexDrop,
+ Hazards: []MigrationHazard{
+ {Type: "INDEX_DROPPED", Message: "Dropping this index means queries that use this index might perform worse because they will no longer will be able to leverage it."},
+ },
+ },
+ {
+ DDL: "DROP INDEX CONCURRENTLY \"replaced_with_same_name_id_00010203-0405-4607-8809-0a0b0c0d0e0f\"",
+ Timeout: statementTimeoutConcurrentIndexDrop,
+ Hazards: []MigrationHazard{
+ {Type: "INDEX_DROPPED", Message: "Dropping this index means queries that use this index might perform worse because they will no longer will be able to leverage it."},
+ },
+ },
+ },
+ },
+ {
+ name: "Index dropped concurrently before columns dropped",
+ oldSchema: schema.Schema{
+ Tables: []schema.Table{
+ {
+ Name: "foobar",
+ Columns: []schema.Column{
+ {Name: "id", Type: "integer"},
+ {Name: "foo", Type: "character varying(255)", Default: "''::character varying", Collation: defaultCollation},
+ {Name: "bar", Type: "timestamp without time zone", IsNullable: true, Default: "CURRENT_TIMESTAMP"},
+ },
+ CheckConstraints: nil,
+ },
+ },
+ Indexes: []schema.Index{
+ {
+ TableName: "foobar",
+ Name: "foobar_pkey", Columns: []string{"id"}, IsPk: true, IsUnique: true, ConstraintName: "foobar_pkey",
+ GetIndexDefStmt: "CREATE UNIQUE INDEX foobar_pkey ON public.foobar USING btree (id)",
+ },
+ {
+ TableName: "foobar",
+ Name: "some_idx", Columns: []string{"foo, bar"},
+ GetIndexDefStmt: "CREATE INDEX some_idx ON public.foobar USING btree (foo, bar)",
+ },
+ },
+ },
+ newSchema: schema.Schema{
+ Tables: []schema.Table{
+ {
+ Name: "foobar",
+ Columns: []schema.Column{
+ {Name: "id", Type: "integer"},
+ },
+ CheckConstraints: nil,
+ },
+ },
+ Indexes: []schema.Index{
+ {
+ TableName: "foobar",
+ Name: "foobar_pkey", Columns: []string{"id"}, IsPk: true, IsUnique: true, ConstraintName: "foobar_pkey",
+ GetIndexDefStmt: "CREATE UNIQUE INDEX foobar_pkey ON public.foobar USING btree (id)",
+ },
+ },
+ },
+ expectedStatements: []Statement{
+ {
+ DDL: "DROP INDEX CONCURRENTLY \"some_idx\"",
+ Timeout: statementTimeoutConcurrentIndexDrop,
+ Hazards: []MigrationHazard{buildIndexDroppedQueryPerfHazard()},
+ },
+ {
+ DDL: "ALTER TABLE \"foobar\" DROP COLUMN \"bar\"",
+ Timeout: statementTimeoutDefault,
+ Hazards: []MigrationHazard{buildColumnDataDeletionHazard()},
+ },
+ {
+ DDL: "ALTER TABLE \"foobar\" DROP COLUMN \"foo\"",
+ Timeout: statementTimeoutDefault,
+ Hazards: []MigrationHazard{buildColumnDataDeletionHazard()},
+ },
+ },
+ },
+ {
+ name: "Invalid index re-created",
+ oldSchema: schema.Schema{
+ Tables: []schema.Table{
+ {
+ Name: "foobar",
+ Columns: []schema.Column{
+ {Name: "id", Type: "integer"},
+ {Name: "foo", Type: "character varying(255)", Default: "''::character varying", Collation: defaultCollation},
+ {Name: "bar", Type: "timestamp without time zone", IsNullable: true, Default: "CURRENT_TIMESTAMP"},
+ },
+ CheckConstraints: nil,
+ },
+ },
+ Indexes: []schema.Index{
+ {
+ TableName: "foobar",
+ Name: "some_idx", Columns: []string{"foo", "bar"},
+ GetIndexDefStmt: "CREATE INDEX some_idx ON public.foobar USING btree (foo, bar)",
+ IsUnique: true, IsInvalid: true,
+ },
+ },
+ },
+ newSchema: schema.Schema{
+ Tables: []schema.Table{
+ {
+ Name: "foobar",
+ Columns: []schema.Column{
+ {Name: "id", Type: "integer"},
+ {Name: "foo", Type: "character varying(255)", Default: "''::character varying", Collation: defaultCollation},
+ {Name: "bar", Type: "timestamp without time zone", IsNullable: true, Default: "CURRENT_TIMESTAMP"},
+ },
+ CheckConstraints: nil,
+ },
+ },
+ Indexes: []schema.Index{
+
+ {
+ TableName: "foobar",
+ Name: "some_idx", Columns: []string{"foo", "bar"},
+ GetIndexDefStmt: "CREATE INDEX some_idx ON public.foobar USING btree (foo, bar)",
+ IsUnique: true,
+ },
+ },
+ },
+ expectedStatements: []Statement{
+ {
+ DDL: "ALTER INDEX \"some_idx\" RENAME TO \"some_idx_10111213-1415-4617-9819-1a1b1c1d1e1f\"",
+ Timeout: statementTimeoutDefault,
+ },
+ {
+ DDL: "CREATE INDEX CONCURRENTLY some_idx ON public.foobar USING btree (foo, bar)",
+ Timeout: statementTimeoutConcurrentIndexBuild,
+ Hazards: []MigrationHazard{buildIndexBuildHazard()},
+ },
+ {
+ DDL: "DROP INDEX CONCURRENTLY \"some_idx_10111213-1415-4617-9819-1a1b1c1d1e1f\"",
+ Timeout: statementTimeoutConcurrentIndexDrop,
+ Hazards: []MigrationHazard{buildIndexDroppedQueryPerfHazard()},
+ },
+ },
+ },
+ {
+ name: "Index replacement on partitioned table",
+ oldSchema: schema.Schema{
+ Tables: []schema.Table{
+ {
+ Name: "foobar",
+ Columns: []schema.Column{
+ {Name: "id", Type: "integer"},
+ {Name: "foo", Type: "character varying(255)", Default: "''::character varying", Collation: defaultCollation},
+ {Name: "bar", Type: "timestamp without time zone", IsNullable: true, Default: "CURRENT_TIMESTAMP"},
+ },
+ CheckConstraints: nil,
+ PartitionKeyDef: "PARTITION BY LIST(foo)",
+ },
+ {
+ ParentTableName: "foobar",
+ Name: "foobar_1",
+ Columns: []schema.Column{
+ {Name: "id", Type: "integer"},
+ {Name: "foo", Type: "character varying(255)", Default: "''::character varying", Collation: defaultCollation},
+ {Name: "bar", Type: "timestamp without time zone", IsNullable: true, Default: "CURRENT_TIMESTAMP"},
+ },
+ CheckConstraints: nil,
+ ForValues: "FOR VALUES IN ('some_val')",
+ },
+ {
+ ParentTableName: "foobar",
+ Name: "foobar_2",
+ Columns: []schema.Column{
+ {Name: "id", Type: "integer"},
+ {Name: "foo", Type: "character varying(255)", Default: "''::character varying", Collation: defaultCollation},
+ {Name: "bar", Type: "timestamp without time zone", IsNullable: true, Default: "CURRENT_TIMESTAMP"},
+ },
+ CheckConstraints: nil,
+ ForValues: "FOR VALUES IN ('some_other_val')",
+ },
+ },
+ Indexes: []schema.Index{
+ // foobar indexes
+ {
+ TableName: "foobar",
+ Name: "some_idx", Columns: []string{"foo", "bar"},
+ GetIndexDefStmt: "CREATE INDEX some_idx ON ONLY public.foobar USING btree (foo, bar)",
+ },
+ {
+ TableName: "foobar",
+ Name: "replaced_with_same_name_idx", Columns: []string{"bar"},
+ GetIndexDefStmt: "CREATE INDEX replaced_with_same_name_idx ON ONLY public.foobar USING btree (bar)",
+ },
+ // foobar_1 indexes
+ {
+ TableName: "foobar_1",
+ Name: "foobar_1_some_idx", Columns: []string{"foo", "bar"}, ParentIdxName: "some_idx",
+ GetIndexDefStmt: "CREATE INDEX foobar_1_some_idx ON public.foobar_1 USING btree (foo, bar)",
+ },
+ {
+ TableName: "foobar_1",
+ Name: "foobar_1_replaced_with_same_name_idx", Columns: []string{"bar"}, ParentIdxName: "replaced_with_same_name_idx",
+ GetIndexDefStmt: "CREATE INDEX foobar_1_replaced_with_same_name_idx ON ONLY public.foobar USING btree (bar)",
+ },
+ {
+ TableName: "foobar_1",
+ Name: "foobar_1_some_local_idx", Columns: []string{"foo", "bar", "id"},
+ GetIndexDefStmt: "CREATE INDEX foobar_1_some_local_idx ON public.foobar_1 USING btree (foo, bar, id)",
+ },
+ // foobar_2 indexes
+ {
+ TableName: "foobar_2",
+ Name: "foobar_2_some_idx", Columns: []string{"foo", "bar"}, ParentIdxName: "some_idx",
+ GetIndexDefStmt: "CREATE INDEX foobar_2_some_idx ON public.foobar_2 USING btree (foo, bar)",
+ },
+ {
+ TableName: "foobar_2",
+ Name: "foobar_2_replaced_with_same_name_idx", Columns: []string{"bar"}, ParentIdxName: "replaced_with_same_name_idx",
+ GetIndexDefStmt: "CREATE INDEX foobar_2_replaced_with_same_name_idx ON ONLY public.foobar USING btree (bar)",
+ },
+ },
+ },
+ newSchema: schema.Schema{
+ Tables: []schema.Table{
+ {
+ Name: "foobar",
+ Columns: []schema.Column{
+ {Name: "id", Type: "integer"},
+ {Name: "foo", Type: "character varying(255)", Default: "''::character varying", Collation: defaultCollation},
+ {Name: "bar", Type: "timestamp without time zone", IsNullable: true, Default: "CURRENT_TIMESTAMP"},
+ },
+ CheckConstraints: nil,
+ PartitionKeyDef: "PARTITION BY LIST(foo)",
+ },
+ {
+ ParentTableName: "foobar",
+ Name: "foobar_1",
+ Columns: []schema.Column{
+ {Name: "id", Type: "integer"},
+ {Name: "foo", Type: "character varying(255)", Default: "''::character varying", Collation: defaultCollation},
+ {Name: "bar", Type: "timestamp without time zone", IsNullable: true, Default: "CURRENT_TIMESTAMP"},
+ },
+ CheckConstraints: nil,
+ ForValues: "FOR VALUES IN ('some_val')",
+ },
+ {
+ ParentTableName: "foobar",
+ Name: "foobar_2",
+ Columns: []schema.Column{
+ {Name: "id", Type: "integer"},
+ {Name: "foo", Type: "character varying(255)", Default: "''::character varying", Collation: defaultCollation},
+ {Name: "bar", Type: "timestamp without time zone", IsNullable: true, Default: "CURRENT_TIMESTAMP"},
+ },
+ CheckConstraints: nil,
+ ForValues: "FOR VALUES IN ('some_other_val')",
+ },
+ },
+ Indexes: []schema.Index{
+ // foobar indexes
+ {
+ TableName: "foobar",
+ Name: "new_some_idx", Columns: []string{"foo", "bar"},
+ GetIndexDefStmt: "CREATE INDEX new_some_idx ON ONLY public.foobar USING btree (foo, bar)",
+ },
+ {
+ TableName: "foobar",
+ Name: "replaced_with_same_name_idx", Columns: []string{"bar", "foo"},
+ GetIndexDefStmt: "CREATE INDEX replaced_with_same_name_idx ON ONLY public.foobar USING btree (bar, foo)",
+ },
+ // foobar_1 indexes
+ {
+ TableName: "foobar_1",
+ Name: "new_foobar_1_some_idx", Columns: []string{"foo", "bar"}, ParentIdxName: "new_some_idx",
+ GetIndexDefStmt: "CREATE INDEX new_foobar_1_some_idx ON public.foobar_1 USING btree (foo, bar)",
+ },
+ {
+ TableName: "foobar_1",
+ Name: "foobar_1_replaced_with_same_name_idx", Columns: []string{"bar", "foo"}, ParentIdxName: "replaced_with_same_name_idx",
+ GetIndexDefStmt: "CREATE INDEX foobar_1_replaced_with_same_name_idx ON public.foobar USING btree (bar, foo)",
+ },
+ {
+ TableName: "foobar_1",
+ Name: "new_foobar_1_some_local_idx", Columns: []string{"foo", "bar", "id"},
+ GetIndexDefStmt: "CREATE INDEX new_foobar_1_some_local_idx ON public.foobar_1 USING btree (foo, bar, id)",
+ },
+ // foobar_2 indexes
+ {
+ TableName: "foobar_2",
+ Name: "new_foobar_2_some_idx", Columns: []string{"foo", "bar"}, ParentIdxName: "new_some_idx",
+ GetIndexDefStmt: "CREATE INDEX new_foobar_2_some_idx ON public.foobar_2 USING btree (foo, bar)",
+ },
+ {
+ TableName: "foobar_2",
+ Name: "foobar_2_replaced_with_same_name_idx", Columns: []string{"bar", "foo"}, ParentIdxName: "replaced_with_same_name_idx",
+ GetIndexDefStmt: "CREATE INDEX foobar_2_replaced_with_same_name_idx ON public.foobar_2 USING btree (bar, foo)",
+ },
+ },
+ },
+ expectedStatements: []Statement{
+ {
+ DDL: "ALTER INDEX \"foobar_1_replaced_with_same_name_idx\" RENAME TO \"foobar_1_replaced_with_sam_30313233-3435-4637-b839-3a3b3c3d3e3f\"",
+ Timeout: statementTimeoutDefault,
+ },
+ {
+ DDL: "ALTER INDEX \"foobar_2_replaced_with_same_name_idx\" RENAME TO \"foobar_2_replaced_with_sam_40414243-4445-4647-8849-4a4b4c4d4e4f\"",
+ Timeout: statementTimeoutDefault,
+ },
+ {
+ DDL: "ALTER INDEX \"replaced_with_same_name_idx\" RENAME TO \"replaced_with_same_name_id_20212223-2425-4627-a829-2a2b2c2d2e2f\"",
+ Timeout: statementTimeoutDefault,
+ },
+ {
+ DDL: "CREATE INDEX new_some_idx ON ONLY public.foobar USING btree (foo, bar)",
+ Timeout: statementTimeoutDefault,
+ },
+ {
+ DDL: "CREATE INDEX replaced_with_same_name_idx ON ONLY public.foobar USING btree (bar, foo)",
+ Timeout: statementTimeoutDefault,
+ },
+ {
+ DDL: "CREATE INDEX CONCURRENTLY foobar_1_replaced_with_same_name_idx ON public.foobar USING btree (bar, foo)",
+ Timeout: statementTimeoutConcurrentIndexDrop,
+ Hazards: []MigrationHazard{
+ buildIndexBuildHazard(),
+ },
+ },
+ {
+ DDL: "ALTER INDEX \"replaced_with_same_name_idx\" ATTACH PARTITION \"foobar_1_replaced_with_same_name_idx\"",
+ Timeout: statementTimeoutDefault,
+ Hazards: nil,
+ },
+ {
+ DDL: "CREATE INDEX CONCURRENTLY new_foobar_1_some_idx ON public.foobar_1 USING btree (foo, bar)",
+ Timeout: statementTimeoutConcurrentIndexBuild,
+ Hazards: []MigrationHazard{
+ buildIndexBuildHazard(),
+ },
+ },
+ {
+ DDL: "ALTER INDEX \"new_some_idx\" ATTACH PARTITION \"new_foobar_1_some_idx\"",
+ Timeout: statementTimeoutDefault,
+ Hazards: nil,
+ },
+ {
+ DDL: "CREATE INDEX CONCURRENTLY new_foobar_1_some_local_idx ON public.foobar_1 USING btree (foo, bar, id)",
+ Timeout: statementTimeoutConcurrentIndexBuild,
+ Hazards: []MigrationHazard{
+ buildIndexBuildHazard(),
+ },
+ },
+ {
+ DDL: "CREATE INDEX CONCURRENTLY foobar_2_replaced_with_same_name_idx ON public.foobar_2 USING btree (bar, foo)",
+ Timeout: statementTimeoutConcurrentIndexBuild,
+ Hazards: []MigrationHazard{
+ buildIndexBuildHazard(),
+ },
+ },
+ {
+ DDL: "ALTER INDEX \"replaced_with_same_name_idx\" ATTACH PARTITION \"foobar_2_replaced_with_same_name_idx\"",
+ Timeout: statementTimeoutDefault,
+ Hazards: nil,
+ },
+ {
+ DDL: "CREATE INDEX CONCURRENTLY new_foobar_2_some_idx ON public.foobar_2 USING btree (foo, bar)",
+ Timeout: statementTimeoutConcurrentIndexBuild,
+ Hazards: []MigrationHazard{
+ buildIndexBuildHazard(),
+ },
+ },
+ {
+ DDL: "ALTER INDEX \"new_some_idx\" ATTACH PARTITION \"new_foobar_2_some_idx\"",
+ Timeout: statementTimeoutDefault,
+ Hazards: nil,
+ },
+ {
+ DDL: "DROP INDEX CONCURRENTLY \"foobar_1_some_local_idx\"",
+ Timeout: statementTimeoutConcurrentIndexDrop,
+ Hazards: []MigrationHazard{
+ buildIndexDroppedQueryPerfHazard(),
+ },
+ },
+ {
+ DDL: "DROP INDEX \"replaced_with_same_name_id_20212223-2425-4627-a829-2a2b2c2d2e2f\"",
+ Timeout: statementTimeoutDefault,
+ Hazards: []MigrationHazard{
+ buildIndexDroppedAcquiresLockHazard(),
+ buildIndexDroppedQueryPerfHazard(),
+ },
+ },
+ {
+ DDL: "DROP INDEX \"some_idx\"",
+ Timeout: statementTimeoutDefault,
+ Hazards: []MigrationHazard{
+ buildIndexDroppedAcquiresLockHazard(),
+ buildIndexDroppedQueryPerfHazard(),
+ },
+ },
+ },
+ },
+ {
+ name: "Local Index dropped concurrently before columns dropped; partitioned index just dropped",
+ oldSchema: schema.Schema{
+ Tables: []schema.Table{
+ {
+ Name: "foobar",
+ Columns: []schema.Column{
+ {Name: "id", Type: "integer"},
+ {Name: "foo", Type: "character varying(255)", Default: "''::character varying", Collation: defaultCollation},
+ {Name: "bar", Type: "timestamp without time zone", IsNullable: true, Default: "CURRENT_TIMESTAMP"},
+ },
+ CheckConstraints: nil,
+ PartitionKeyDef: "PARTITION BY LIST(foo)",
+ },
+ {
+ ParentTableName: "foobar",
+ Name: "foobar_1",
+ Columns: []schema.Column{
+ {Name: "id", Type: "integer"},
+ {Name: "foo", Type: "character varying(255)", Default: "''::character varying", Collation: defaultCollation},
+ {Name: "bar", Type: "timestamp without time zone", IsNullable: true, Default: "CURRENT_TIMESTAMP"},
+ },
+ CheckConstraints: nil,
+ ForValues: "FOR VALUES IN ('some_val')",
+ },
+ },
+ Indexes: []schema.Index{
+ // foobar indexes
+ {
+ TableName: "foobar",
+ Name: "foobar_pkey", Columns: []string{"foo", "id"}, IsPk: true, IsUnique: true, ConstraintName: "foobar_pkey",
+ GetIndexDefStmt: "CREATE UNIQUE INDEX foobar_pkey ON ONLY public.foobar USING btree (foo, id)",
+ },
+ {
+ TableName: "foobar",
+ Name: "some_idx", Columns: []string{"foo, bar"},
+ GetIndexDefStmt: "CREATE INDEX some_idx ON ONLY public.foobar USING btree (foo, bar)",
+ },
+ // foobar_1 indexes
+ {
+ TableName: "foobar_1",
+ Name: "foobar_1_pkey", Columns: []string{"foo", "id"}, IsPk: true, IsUnique: true, ConstraintName: "foobar_pkey", ParentIdxName: "foobar_pkey",
+ GetIndexDefStmt: "CREATE UNIQUE INDEX foobar_1_pkey ON public.foobar_1 USING btree (foo, id)",
+ },
+ {
+ TableName: "foobar_1",
+ Name: "foobar_1_some_idx", Columns: []string{"foo", "bar"}, ParentIdxName: "some_idx",
+ GetIndexDefStmt: "CREATE INDEX foobar_1_some_idx ON public.foobar_1 USING btree (foo, bar)",
+ },
+ {
+ TableName: "foobar_1",
+ Name: "foobar_1_some_local_idx", Columns: []string{"foo", "bar", "id"},
+ GetIndexDefStmt: "CREATE INDEX foobar_1_some_local_idx ON public.foobar_1 USING btree (foo, bar, id)",
+ },
+ },
+ },
+ newSchema: schema.Schema{
+ Tables: []schema.Table{
+ {
+ Name: "foobar",
+ Columns: []schema.Column{
+ {Name: "id", Type: "integer"},
+ {Name: "foo", Type: "character varying(255)", Default: "''::character varying", Collation: defaultCollation},
+ },
+ CheckConstraints: nil,
+ PartitionKeyDef: "PARTITION BY LIST(foo)",
+ },
+ {
+ ParentTableName: "foobar",
+ Name: "foobar_1",
+ Columns: []schema.Column{
+ {Name: "id", Type: "integer"},
+ {Name: "foo", Type: "character varying(255)", Default: "''::character varying", Collation: defaultCollation},
+ },
+ CheckConstraints: nil,
+ ForValues: "FOR VALUES IN ('some_val')",
+ },
+ },
+ Indexes: []schema.Index{
+ // foobar indexes
+ {
+ TableName: "foobar",
+ Name: "foobar_pkey", Columns: []string{"foo", "id"}, IsPk: true, IsUnique: true, ConstraintName: "foobar_pkey",
+ GetIndexDefStmt: "CREATE UNIQUE INDEX foobar_pkey ON ONLY public.foobar USING btree (foo, id)",
+ },
+ // foobar_1 indexes
+ {
+ TableName: "foobar_1",
+ Name: "foobar_1_pkey", Columns: []string{"foo", "id"}, IsPk: true, IsUnique: true, ConstraintName: "foobar_pkey", ParentIdxName: "foobar_pkey",
+ GetIndexDefStmt: "CREATE UNIQUE INDEX foobar_1_pkey ON public.foobar_1 USING btree (foo, id)",
+ },
+ },
+ },
+ expectedStatements: []Statement{
+ {
+ DDL: "DROP INDEX CONCURRENTLY \"foobar_1_some_local_idx\"",
+ Timeout: statementTimeoutConcurrentIndexDrop,
+ Hazards: []MigrationHazard{
+ buildIndexDroppedQueryPerfHazard(),
+ },
+ },
+ {
+ DDL: "DROP INDEX \"some_idx\"",
+ Timeout: statementTimeoutDefault,
+ Hazards: []MigrationHazard{
+ buildIndexDroppedAcquiresLockHazard(),
+ buildIndexDroppedQueryPerfHazard(),
+ },
+ },
+ {
+ DDL: "ALTER TABLE \"foobar\" DROP COLUMN \"bar\"",
+ Timeout: statementTimeoutDefault,
+ Hazards: []MigrationHazard{
+ buildColumnDataDeletionHazard(),
+ },
+ },
+ },
+ },
+ {
+ name: "Invalid index of partitioned index re-created but original index remains untouched",
+ oldSchema: schema.Schema{
+ Tables: []schema.Table{
+ {
+ Name: "foobar",
+ Columns: []schema.Column{
+ {Name: "id", Type: "integer"},
+ {Name: "foo", Type: "character varying(255)", Default: "''::character varying", Collation: defaultCollation},
+ {Name: "bar", Type: "timestamp without time zone", IsNullable: true, Default: "CURRENT_TIMESTAMP"},
+ },
+ CheckConstraints: nil,
+ PartitionKeyDef: "PARTITION BY LIST(foo)",
+ },
+ {
+ ParentTableName: "foobar",
+ Name: "foobar_1",
+ Columns: []schema.Column{
+ {Name: "id", Type: "integer"},
+ {Name: "foo", Type: "character varying(255)", Default: "''::character varying", Collation: defaultCollation},
+ {Name: "bar", Type: "timestamp without time zone", IsNullable: true, Default: "CURRENT_TIMESTAMP"},
+ },
+ CheckConstraints: nil,
+ ForValues: "FOR VALUES IN ('some_val')",
+ },
+ },
+ Indexes: []schema.Index{
+ // foobar indexes
+ {
+ TableName: "foobar",
+ Name: "some_idx", Columns: []string{"foo, bar"},
+ GetIndexDefStmt: "CREATE INDEX some_idx ON ONLY public.foobar USING btree (foo, bar)",
+ IsInvalid: true,
+ },
+ // foobar_1 indexes
+ {
+ TableName: "foobar_1",
+ Name: "foobar_1_some_idx", Columns: []string{"foo", "bar"},
+ GetIndexDefStmt: "CREATE INDEX foobar_1_some_idx ON public.foobar_1 USING btree (foo, bar)",
+ IsInvalid: true,
+ },
+ },
+ },
+ newSchema: schema.Schema{
+ Tables: []schema.Table{
+ {
+ Name: "foobar",
+ Columns: []schema.Column{
+ {Name: "id", Type: "integer"},
+ {Name: "foo", Type: "character varying(255)", Default: "''::character varying", Collation: defaultCollation},
+ {Name: "bar", Type: "timestamp without time zone", IsNullable: true, Default: "CURRENT_TIMESTAMP"},
+ },
+ CheckConstraints: nil,
+ PartitionKeyDef: "PARTITION BY LIST(foo)",
+ },
+ {
+ ParentTableName: "foobar",
+ Name: "foobar_1",
+ Columns: []schema.Column{
+ {Name: "id", Type: "integer"},
+ {Name: "foo", Type: "character varying(255)", Default: "''::character varying", Collation: defaultCollation},
+ {Name: "bar", Type: "timestamp without time zone", IsNullable: true, Default: "CURRENT_TIMESTAMP"},
+ },
+ CheckConstraints: nil,
+ ForValues: "FOR VALUES IN ('some_val')",
+ },
+ },
+ Indexes: []schema.Index{
+ // foobar indexes
+ {
+ TableName: "foobar",
+ Name: "some_idx", Columns: []string{"foo, bar"},
+ GetIndexDefStmt: "CREATE INDEX some_idx ON ONLY public.foobar USING btree (foo, bar)",
+ },
+ // foobar_1 indexes
+ {
+ TableName: "foobar_1",
+ Name: "foobar_1_some_idx", Columns: []string{"foo", "bar"}, ParentIdxName: "some_idx",
+ GetIndexDefStmt: "CREATE INDEX foobar_1_some_idx ON public.foobar_1 USING btree (foo, bar)",
+ },
+ },
+ },
+ expectedStatements: []Statement{
+ {
+ DDL: "ALTER INDEX \"foobar_1_some_idx\" RENAME TO \"foobar_1_some_idx_50515253-5455-4657-9859-5a5b5c5d5e5f\"",
+ Timeout: statementTimeoutDefault,
+ },
+ {
+ DDL: "CREATE INDEX CONCURRENTLY foobar_1_some_idx ON public.foobar_1 USING btree (foo, bar)",
+ Timeout: statementTimeoutConcurrentIndexBuild,
+ Hazards: []MigrationHazard{
+ buildIndexBuildHazard(),
+ },
+ },
+ {
+ DDL: "ALTER INDEX \"some_idx\" ATTACH PARTITION \"foobar_1_some_idx\"",
+ Timeout: statementTimeoutDefault,
+ },
+ {
+ DDL: "DROP INDEX CONCURRENTLY \"foobar_1_some_idx_50515253-5455-4657-9859-5a5b5c5d5e5f\"",
+ Timeout: statementTimeoutConcurrentIndexDrop,
+ Hazards: []MigrationHazard{
+ buildIndexDroppedQueryPerfHazard(),
+ },
+ },
+ },
+ },
+ {
+ name: "Fails on duplicate column in old schema",
+ oldSchema: schema.Schema{
+ Tables: []schema.Table{
+ {
+ Name: "foobar",
+ Columns: []schema.Column{
+ {Name: "id", Type: "integer"},
+ {Name: "id", Type: "character varying(255)", Default: "''::character varying", Collation: defaultCollation},
+ },
+ CheckConstraints: nil,
+ },
+ },
+ },
+ newSchema: schema.Schema{
+ Tables: []schema.Table{
+ {
+ Name: "foobar",
+ Columns: []schema.Column{
+ {Name: "id", Type: "integer"},
+ {Name: "something", Type: "character varying(255)", Default: "''::character varying", Collation: defaultCollation},
+ },
+ CheckConstraints: nil,
+ },
+ },
+ },
+ expectedStatements: nil,
+ expectedDiffErrIs: errDuplicateIdentifier,
+ },
+ {
+ name: "Invalid check constraint made valid",
+ oldSchema: schema.Schema{
+ Tables: []schema.Table{
+ {
+ Name: "foobar",
+ Columns: []schema.Column{
+ {Name: "id", Type: "integer"},
+ },
+ CheckConstraints: []schema.CheckConstraint{
+ {Name: "id_check", Expression: "(id > 0)", IsInheritable: true},
+ },
+ },
+ },
+ },
+ newSchema: schema.Schema{
+ Tables: []schema.Table{
+ {
+ Name: "foobar",
+ Columns: []schema.Column{
+ {Name: "id", Type: "integer"},
+ },
+ CheckConstraints: []schema.CheckConstraint{
+ {Name: "id_check", Expression: "(id > 0)", IsInheritable: true, IsValid: true},
+ },
+ },
+ },
+ },
+ expectedStatements: []Statement{
+ {
+ DDL: "ALTER TABLE \"foobar\" VALIDATE CONSTRAINT \"id_check\"",
+ Timeout: statementTimeoutDefault,
+ Hazards: nil,
+ },
+ },
+ },
+ {
+ name: "Invalid check constraint re-created if expression changes",
+ oldSchema: schema.Schema{
+ Tables: []schema.Table{
+ {
+ Name: "foobar",
+ Columns: []schema.Column{
+ {Name: "id", Type: "integer"},
+ },
+ CheckConstraints: []schema.CheckConstraint{
+ {Name: "id_check", Expression: "(id > 0)", IsInheritable: true},
+ },
+ },
+ },
+ },
+ newSchema: schema.Schema{
+ Tables: []schema.Table{
+ {
+ Name: "foobar",
+ Columns: []schema.Column{
+ {Name: "id", Type: "integer"},
+ },
+ CheckConstraints: []schema.CheckConstraint{
+ {Name: "id_check", Expression: "(id < 0)", IsInheritable: true, IsValid: true},
+ },
+ },
+ },
+ },
+ expectedStatements: []Statement{
+ {
+ DDL: "ALTER TABLE \"foobar\" DROP CONSTRAINT \"id_check\"",
+ Timeout: statementTimeoutDefault,
+ },
+ {
+ DDL: "ALTER TABLE \"foobar\" ADD CONSTRAINT \"id_check\" CHECK((id < 0))",
+ Timeout: statementTimeoutDefault,
+ },
+ },
+ },
+ {
+ name: "BIGINT to TIMESTAMP type conversion",
+ oldSchema: schema.Schema{
+ Tables: []schema.Table{
+ {
+ Name: "foobar",
+ Columns: []schema.Column{
+ {Name: "baz", Type: "bigint"},
+ },
+ CheckConstraints: nil,
+ },
+ },
+ Indexes: nil,
+ },
+ newSchema: schema.Schema{
+ Tables: []schema.Table{
+ {
+ Name: "foobar",
+ Columns: []schema.Column{
+ {Name: "baz", Type: "timestamp without time zone", Default: "current_timestamp"},
+ },
+ CheckConstraints: nil,
+ },
+ },
+ Indexes: nil,
+ },
+ expectedStatements: []Statement{
+ {
+ DDL: "ALTER TABLE \"foobar\" ALTER COLUMN \"baz\" SET DATA TYPE timestamp without time zone using to_timestamp(\"baz\" / 1000)",
+ Timeout: statementTimeoutDefault,
+ Hazards: []MigrationHazard{{
+ Type: MigrationHazardTypeAcquiresAccessExclusiveLock,
+ Message: "This will completely lock the table while the data is being " +
+ "re-written for a duration of time that scales with the size of your " +
+ "data. The values previously stored as BIGINT will be translated into a " +
+ "TIMESTAMP value via the PostgreSQL to_timestamp() function. This " +
+ "translation will assume that the values stored in BIGINT represent a " +
+ "millisecond epoch value.",
+ }},
+ },
+ {
+ DDL: "ANALYZE \"foobar\" (\"baz\")",
+ Timeout: statementTimeoutAnalyzeColumn,
+ Hazards: []MigrationHazard{buildAnalyzeColumnMigrationHazard()},
+ },
+ {
+ DDL: "ALTER TABLE \"foobar\" ALTER COLUMN \"baz\" SET DEFAULT current_timestamp",
+ Timeout: statementTimeoutDefault,
+ },
+ },
+ },
+ {
+ name: "Collation migration and Type Migration",
+ oldSchema: schema.Schema{
+ Tables: []schema.Table{
+ {
+ Name: "foobar",
+ Columns: []schema.Column{
+ {Name: "migrate_to_c_coll", Type: "text", Collation: defaultCollation},
+ {Name: "migrate_type", Type: "text", Collation: defaultCollation},
+ },
+ CheckConstraints: nil,
+ },
+ },
+ },
+ newSchema: schema.Schema{
+ Tables: []schema.Table{
+ {
+ Name: "foobar",
+ Columns: []schema.Column{
+ {Name: "migrate_to_c_coll", Type: "text", Collation: cCollation},
+ {Name: "migrate_type", Type: "character varying(255)", Collation: defaultCollation},
+ },
+ CheckConstraints: nil,
+ },
+ },
+ },
+ expectedStatements: []Statement{
+ {
+ DDL: "ALTER TABLE \"foobar\" ALTER COLUMN \"migrate_to_c_coll\" SET DATA TYPE text COLLATE \"pg_catalog\".\"C\" using \"migrate_to_c_coll\"::text",
+ Timeout: statementTimeoutDefault,
+ Hazards: []MigrationHazard{buildColumnTypeChangeHazard()},
+ },
+ {
+ DDL: "ANALYZE \"foobar\" (\"migrate_to_c_coll\")",
+ Timeout: statementTimeoutAnalyzeColumn,
+ Hazards: []MigrationHazard{buildAnalyzeColumnMigrationHazard()},
+ },
+ {
+ DDL: "ALTER TABLE \"foobar\" ALTER COLUMN \"migrate_type\" SET DATA TYPE character varying(255) COLLATE \"pg_catalog\".\"default\" using \"migrate_type\"::character varying(255)",
+ Timeout: statementTimeoutDefault,
+ Hazards: []MigrationHazard{buildColumnTypeChangeHazard()},
+ },
+ {
+ DDL: "ANALYZE \"foobar\" (\"migrate_type\")",
+ Timeout: statementTimeoutAnalyzeColumn,
+ Hazards: []MigrationHazard{buildAnalyzeColumnMigrationHazard()},
+ },
+ },
+ },
+ }
+)
+
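+// deterministicRandReader is a deterministic stand-in for the uuid package's random source, so the UUID suffixes
+// in generated index rename statements are stable across test runs.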
+type deterministicRandReader struct {
+ counter int8
+}
+
+func (r *deterministicRandReader) Read(p []byte) (int, error) {
+ for i := 0; i < len(p); i++ {
+ p[i] = byte(r.counter)
+ r.counter++
+ }
+ return len(p), nil
+}
+
+func TestSchemaMigrationPlanTest(t *testing.T) {
+ uuid.SetRand(&deterministicRandReader{})
+
+ for _, testCase := range schemaMigrationPlanTestCases {
+ t.Run(testCase.name, func(t *testing.T) {
+ schemaDiff, _, err := buildSchemaDiff(testCase.oldSchema, testCase.newSchema)
+ if testCase.expectedDiffErrIs != nil {
+ require.ErrorIs(t, err, testCase.expectedDiffErrIs)
+ } else {
+ require.NoError(t, err)
+ }
+ stmts, err := schemaSQLGenerator{}.Alter(schemaDiff)
+ require.NoError(t, err)
+ assert.Equal(t, testCase.expectedStatements, stmts, "actual:\n %# v", pretty.Formatter(stmts))
+ })
+ }
+}
+
+func buildColumnDataDeletionHazard() MigrationHazard {
+ return MigrationHazard{
+ Type: MigrationHazardTypeDeletesData,
+ Message: "Deletes all values in the column",
+ }
+}
+
+func buildColumnTypeChangeHazard() MigrationHazard {
+ return MigrationHazard{
+ Type: MigrationHazardTypeAcquiresAccessExclusiveLock,
+ Message: "This will completely lock the table while the data is being re-written. The duration of this " +
+ "conversion depends on if the type conversion is trivial or not. A non-trivial conversion will require a " +
+ "table rewrite. A trivial conversion is one where the binary values are coercible and the column contents " +
+ "are not changing.",
+ }
+}
+
+func buildAnalyzeColumnMigrationHazard() MigrationHazard {
+ return MigrationHazard{
+ Type: MigrationHazardTypeImpactsDatabasePerformance,
+ Message: "Running analyze will read rows from the table, putting increased load " +
+ "on the database and consuming database resources. It won't prevent reads/writes to " +
+ "the table, but it could affect performance when executing queries.",
+ }
+}
+
+func buildIndexBuildHazard() MigrationHazard {
+ return MigrationHazard{
+ Type: MigrationHazardTypeIndexBuild,
+ Message: "This might affect database performance. " +
+ "Concurrent index builds require a non-trivial amount of CPU, potentially affecting database performance. " +
+ "They also can take a while but do not lock out writes.",
+ }
+}
+
+func buildIndexDroppedQueryPerfHazard() MigrationHazard {
+ return MigrationHazard{
+ Type: MigrationHazardTypeIndexDropped,
+ Message: "Dropping this index means queries that use this index might perform worse because " +
+ "they will no longer will be able to leverage it.",
+ }
+}
+
+func buildIndexDroppedAcquiresLockHazard() MigrationHazard {
+ return MigrationHazard{
+ Type: MigrationHazardTypeAcquiresAccessExclusiveLock,
+ Message: "Index drops will lock out all accesses to the table. They should be fast",
+ }
+}
diff --git a/pkg/diff/sql_generator.go b/pkg/diff/sql_generator.go
new file mode 100644
index 0000000..ae03298
--- /dev/null
+++ b/pkg/diff/sql_generator.go
@@ -0,0 +1,1431 @@
+package diff
+
+import (
+ "fmt"
+ "strings"
+ "time"
+
+ "github.com/google/go-cmp/cmp"
+ "github.com/google/uuid"
+ "github.com/stripe/pg-schema-diff/internal/schema"
+)
+
+const (
+ maxPostgresIdentifierSize = 63
+
+ statementTimeoutDefault = 3 * time.Second
+ // statementTimeoutConcurrentIndexBuild is the statement timeout for concurrent index builds. It may take a while
+ // to build the index. Since it doesn't take out locks, this shouldn't be a concern.
+ statementTimeoutConcurrentIndexBuild = 20 * time.Minute
+ // statementTimeoutConcurrentIndexDrop is the statement timeout for concurrent index drops. This operation shouldn't
+ // take out locks except when changing table metadata, but it may take a while to complete, so give it a long
+ // timeout
+ statementTimeoutConcurrentIndexDrop = 20 * time.Minute
+ // statementTimeoutTableDrop is the statement timeout for table drops. It may take a while to delete the data.
+ // Since the table is being dropped, locks shouldn't be a concern.
+ statementTimeoutTableDrop = 20 * time.Minute
+ // statementTimeoutAnalyzeColumn is the statement timeout for analyzing the column of a table
+ statementTimeoutAnalyzeColumn = 20 * time.Minute
+)
+
+var (
+ ErrColumnOrderingChanged = fmt.Errorf("column ordering changed: %w", ErrNotImplemented)
+
+ migrationHazardAddAlterFunctionCannotTrackDependencies = MigrationHazard{
+ Type: MigrationHazardTypeHasUntrackableDependencies,
+ Message: "Dependencies, i.e. other functions used in the function body, of non-sql functions cannot be tracked. " +
+ "As a result, we cannot guarantee that function dependencies are ordered properly relative to this " +
+ "statement. For adds, this means you need to ensure that all functions this function depends on are " +
+ "created/altered before this statement.",
+ }
+ migrationHazardIndexDroppedQueryPerf = MigrationHazard{
+ Type: MigrationHazardTypeIndexDropped,
+ Message: "Dropping this index means queries that use this index might perform worse because " +
+ "they will no longer will be able to leverage it.",
+ }
+ migrationHazardIndexDroppedAcquiresLock = MigrationHazard{
+ Type: MigrationHazardTypeAcquiresAccessExclusiveLock,
+ Message: "Index drops will lock out all accesses to the table. They should be fast",
+ }
+)
+
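+// oldAndNew pairs the old and new versions of a schema object being diffed.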
+type oldAndNew[S schema.Object] struct {
+ old S
+ new S
+}
+
+func (o oldAndNew[S]) GetNew() S {
+ return o.new
+}
+
+func (o oldAndNew[S]) GetOld() S {
+ return o.old
+}
+
+type (
+ columnDiff struct {
+ oldAndNew[schema.Column]
+ oldOrdering int
+ newOrdering int
+ }
+
+ checkConstraintDiff struct {
+ oldAndNew[schema.CheckConstraint]
+ }
+
+ tableDiff struct {
+ oldAndNew[schema.Table]
+ columnsDiff listDiff[schema.Column, columnDiff]
+ checkConstraintDiff listDiff[schema.CheckConstraint, checkConstraintDiff]
+ }
+
+ indexDiff struct {
+ oldAndNew[schema.Index]
+ }
+
+ functionDiff struct {
+ oldAndNew[schema.Function]
+ }
+
+ triggerDiff struct {
+ oldAndNew[schema.Trigger]
+ }
+)
+
+type schemaDiff struct {
+ oldAndNew[schema.Schema]
+ tableDiffs listDiff[schema.Table, tableDiff]
+ indexDiffs listDiff[schema.Index, indexDiff]
+ functionDiffs listDiff[schema.Function, functionDiff]
+ triggerDiffs listDiff[schema.Trigger, triggerDiff]
+}
+
+func (sd schemaDiff) resolveToSQL() ([]Statement, error) {
+ return schemaSQLGenerator{}.Alter(sd)
+}
+
+// The procedure for DIFFING schemas and GENERATING/RESOLVING the SQL required to migrate the old schema to the new schema is
+// described below:
+//
+// A schema follows a hierarchy: Schemas -> Tables -> Columns and Indexes
+// Every level of the hierarchy can depend on other items at the same level (indexes depend on columns).
+// A similar idea applies with constraints, including Foreign key constraints. Because constraints can have cross-table
+// dependencies, they can be viewed at the same level as tables. This hierarchy becomes interwoven with partitions
+//
+// Diffing two sets of schema objects follows a common pattern:
+// (DIFFING)
+// 1. Diff two lists of schema objects (e.g., schemas, tables). An item is new if its name is not present in the old list.
+// An item is deleted if its name is not present in the new list. Otherwise, an item might have been altered
+// 2. For each potentially altered item, generate the diff between the old and new. This might involve diffing lists if they have
+// nested items (recursing into step 1)
+// (GENERATING/RESOLVING)
+// 3. Generate the SQL required for the deleted items, the new items (ADDS), and the altered items.
+// These items might have interwoven dependencies
+// 4. Topologically sort the diffed items
+//
+// The diffing is handled by diffLists. diffLists takes two lists of schema objects and identifies which are
+// added, deleted, or potentially altered. If an item is potentially altered, it passes the old and new versions
+// to a callback which handles diffing them. This callback might call into diffLists
+// for items nested inside its hierarchy
+//
+// Generating the SQL for the resulting diff from the two lists of items is handled by the SQL(Vertex)Generators.
+// Every schema object defines a SQL(Vertex)Generator. A SQL(Vertex)Generator generates the SQL required to add, delete,
+// or alter a schema object. If altering a schema object, the SQL(Vertex)Generator is passed the diff generated by the callback in diffLists.
+// The sqlGenerator just generates SQL, while the sqlVertexGenerator also defines dependencies that a schema object has
+// on other schema objects
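+//
+// For example (illustrative only), diffing old tables [foo, bar] against new tables [bar, baz] yields baz as an
+// add, foo as a delete, and bar as a potential alter that is passed to the table-diffing callback.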
+
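+// buildSchemaDiff diffs the tables, indexes, functions, and triggers of the old and new schemas.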
+func buildSchemaDiff(old, new schema.Schema) (schemaDiff, bool, error) {
+ tableDiffs, err := diffLists(old.Tables, new.Tables, buildTableDiff)
+ if err != nil {
+ return schemaDiff{}, false, fmt.Errorf("diffing tables: %w", err)
+ }
+
+ newSchemaTablesByName := buildSchemaObjMap(new.Tables)
+ addedTablesByName := buildSchemaObjMap(tableDiffs.adds)
+ indexesDiff, err := diffLists(old.Indexes, new.Indexes, func(old, new schema.Index, oldIndex, newIndex int) (indexDiff, bool, error) {
+ return buildIndexDiff(newSchemaTablesByName, addedTablesByName, old, new, oldIndex, newIndex)
+ })
+ if err != nil {
+ return schemaDiff{}, false, fmt.Errorf("diffing indexes: %w", err)
+ }
+
+ functionDiffs, err := diffLists(old.Functions, new.Functions, func(old, new schema.Function, _, _ int) (functionDiff, bool, error) {
+ return functionDiff{
+ oldAndNew[schema.Function]{
+ old: old,
+ new: new,
+ },
+ }, false, nil
+ })
+ if err != nil {
+ return schemaDiff{}, false, fmt.Errorf("diffing functions: %w", err)
+ }
+
+ triggerDiffs, err := diffLists(old.Triggers, new.Triggers, func(old, new schema.Trigger, _, _ int) (triggerDiff, bool, error) {
+ if _, isOnNewTable := addedTablesByName[new.OwningTableUnescapedName]; isOnNewTable {
+ // If the table is new, then it must be re-created (this occurs if the base table has been
+ // re-created). In other words, a trigger must be re-created if the owning table is re-created
+ return triggerDiff{}, true, nil
+ }
+ return triggerDiff{
+ oldAndNew[schema.Trigger]{
+ old: old,
+ new: new,
+ },
+ }, false, nil
+ })
+ if err != nil {
+ return schemaDiff{}, false, fmt.Errorf("diffing triggers: %w", err)
+ }
+
+ return schemaDiff{
+ oldAndNew: oldAndNew[schema.Schema]{
+ old: old,
+ new: new,
+ },
+ tableDiffs: tableDiffs,
+ indexDiffs: indexesDiff,
+ functionDiffs: functionDiffs,
+ triggerDiffs: triggerDiffs,
+ }, false, nil
+}
+
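+// buildTableDiff diffs two versions of a table, recursing into its columns and check constraints. A table is
+// flagged for re-creation if it gains or loses partitioning or if its parent table changes; changing an existing
+// partition key definition is not yet supported.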
+func buildTableDiff(oldTable, newTable schema.Table, _, _ int) (diff tableDiff, requiresRecreation bool, err error) {
+ if oldTable.IsPartitioned() != newTable.IsPartitioned() {
+ return tableDiff{}, true, nil
+ } else if oldTable.PartitionKeyDef != newTable.PartitionKeyDef {
+ // We won't support changing partition key def due to issues with requiresRecreation.
+ //
+ // BLUF of the problem: If you have a flattened hierarchy (partitions, materialized views) and the parent
+ // is re-created but the children are unchanged, the children need to be re-created.
+ //
+ // If we want to add support, then we need diffLists to identify if a parent has been re-created (or if parents have changed),
+ // so it knows to re-create the child. This problem becomes more acute when a child can belong to
+ // multiple parents, e.g., materialized views. Ultimately, it's a graph problem in diffLists that can
+ // be solved through a `getParents` function
+ //
+ // Until the above is implemented, we can't support requiresRecreation on any flattened hierarchies
+ return tableDiff{}, false, fmt.Errorf("changing partition key def: %w", ErrNotImplemented)
+ }
+
+ if oldTable.ParentTableName != newTable.ParentTableName {
+ // Since diffLists doesn't handle re-creating hierarchies that change, we need to manually
+ // identify if the hierarchy has changed. This approach will NOT work if we support multiple layers
+ // of partitioning because it's possible the parent's parent changed but the parent remained the same
+ return tableDiff{}, true, nil
+ }
+
+ columnsDiff, err := diffLists(
+ oldTable.Columns,
+ newTable.Columns,
+ func(old, new schema.Column, oldIndex, newIndex int) (columnDiff, bool, error) {
+ return columnDiff{
+ oldAndNew: oldAndNew[schema.Column]{old: old, new: new},
+ oldOrdering: oldIndex,
+ newOrdering: newIndex,
+ }, false, nil
+ },
+ )
+ if err != nil {
+ return tableDiff{}, false, fmt.Errorf("diffing columns: %w", err)
+ }
+
+ checkConsDiff, err := diffLists(
+ oldTable.CheckConstraints,
+ newTable.CheckConstraints,
+ func(old, new schema.CheckConstraint, _, _ int) (checkConstraintDiff, bool, error) {
+ recreateConstraint := (old.Expression != new.Expression) ||
+ (old.IsValid && !new.IsValid) ||
+ (old.IsInheritable != new.IsInheritable)
+ return checkConstraintDiff{oldAndNew[schema.CheckConstraint]{old: old, new: new}},
+ recreateConstraint,
+ nil
+ },
+ )
+ if err != nil {
+ return tableDiff{}, false, fmt.Errorf("diffing lists: %w", err)
+ }
+
+ return tableDiff{
+ oldAndNew: oldAndNew[schema.Table]{
+ old: oldTable,
+ new: newTable,
+ },
+ columnsDiff: columnsDiff,
+ checkConstraintDiff: checkConsDiff,
+ }, false, nil
+}
+
+ // buildIndexDiff builds the index diff, resolving differences that can be handled without re-creating the index
+ // (e.g., attaching the index to a parent index or backing a new primary key)
+func buildIndexDiff(newSchemaTablesByName map[string]schema.Table, addedTablesByName map[string]schema.Table, old, new schema.Index, _, _ int) (diff indexDiff, requiresRecreation bool, err error) {
+ updatedOld := old
+
+ if _, isOnNewTable := addedTablesByName[new.TableName]; isOnNewTable {
+ // If the owning table is new (which happens when the base table has been re-created), the
+ // index must also be re-created. In other words, an index must be re-created if the owning table is re-created
+ return indexDiff{}, true, nil
+ }
+
+ if len(old.ParentIdxName) == 0 {
+ // If the old index didn't belong to a partitioned index but the new index does, we can resolve
+ // the parent index name diff by attaching the index to its parent.
+ // We can't switch an index partition from one parent to another; in that instance, we must
+ // re-create the index
+ updatedOld.ParentIdxName = new.ParentIdxName
+ }
+
+ if !new.IsPartitionOfIndex() && !old.IsPk && new.IsPk {
+ // If the old index is not part of a primary key and the new index is part of a primary key,
+ // the constraint name diff is resolvable by adding the index to the primary key.
+ // Partitioned indexes are the exception; for partitioned indexes that are
+ // primary keys, the indexes are created with the constraint on the base table and cannot
+ // be attached to the base index
+ // In the future, we can change this behavior to ONLY create the constraint on the base table
+ // and follow a similar paradigm to adding indexes
+ updatedOld.ConstraintName = new.ConstraintName
+ updatedOld.IsPk = new.IsPk
+ }
+
+ if isOnPartitionedTable, err := isOnPartitionedTable(newSchemaTablesByName, new); err != nil {
+ return indexDiff{}, false, err
+ } else if isOnPartitionedTable && old.IsInvalid && !new.IsInvalid {
+ // If the index is a partitioned index, it can be made valid automatically by attaching the index partitions
+ // We don't need to re-create it.
+ updatedOld.IsInvalid = new.IsInvalid
+ }
+
+ recreateIndex := !cmp.Equal(updatedOld, new)
+ return indexDiff{
+ oldAndNew: oldAndNew[schema.Index]{
+ old: old, new: new,
+ },
+ }, recreateIndex, nil
+}
+
+type schemaSQLGenerator struct{}
+
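+// Alter generates the ordered migration statements for a schema diff. It builds a SQL vertex graph per
+// object type (tables, attach-partition statements, index renames, indexes, functions, triggers), unions
+// them into a single graph, and topologically sorts the result into a list of statements.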
+func (schemaSQLGenerator) Alter(diff schemaDiff) ([]Statement, error) {
+ tablesInNewSchemaByName := buildSchemaObjMap(diff.new.Tables)
+ deletedTablesByName := buildSchemaObjMap(diff.tableDiffs.deletes)
+
+ tableSQLVertexGenerator := tableSQLVertexGenerator{
+ deletedTablesByName: deletedTablesByName,
+ tablesInNewSchemaByName: tablesInNewSchemaByName,
+ }
+ tableGraphs, err := diff.tableDiffs.resolveToSQLGraph(&tableSQLVertexGenerator)
+ if err != nil {
+ return nil, fmt.Errorf("resolving table sql graphs: %w", err)
+ }
+
+ indexesInNewSchemaByTableName := make(map[string][]schema.Index)
+ for _, idx := range diff.new.Indexes {
+ indexesInNewSchemaByTableName[idx.TableName] = append(indexesInNewSchemaByTableName[idx.TableName], idx)
+ }
+ attachPartitionSQLVertexGenerator := attachPartitionSQLVertexGenerator{
+ indexesInNewSchemaByTableName: indexesInNewSchemaByTableName,
+ }
+ attachPartitionGraphs, err := diff.tableDiffs.resolveToSQLGraph(&attachPartitionSQLVertexGenerator)
+ if err != nil {
+ return nil, fmt.Errorf("resolving attach partition sql graphs: %w", err)
+ }
+
+ renameConflictingIndexSQLVertexGenerator := newRenameConflictingIndexSQLVertexGenerator(buildSchemaObjMap(diff.old.Indexes))
+ renameConflictingIndexGraphs, err := diff.indexDiffs.resolveToSQLGraph(&renameConflictingIndexSQLVertexGenerator)
+ if err != nil {
+ return nil, fmt.Errorf("resolving renaming conflicting indexes: %w", err)
+ }
+
+ indexSQLVertexGenerator := indexSQLVertexGenerator{
+ deletedTablesByName: deletedTablesByName,
+ addedTablesByName: buildSchemaObjMap(diff.tableDiffs.adds),
+ tablesInNewSchemaByName: tablesInNewSchemaByName,
+ indexesInNewSchemaByName: buildSchemaObjMap(diff.new.Indexes),
+ indexRenamesByOldName: renameConflictingIndexSQLVertexGenerator.getRenames(),
+ }
+ indexGraphs, err := diff.indexDiffs.resolveToSQLGraph(&indexSQLVertexGenerator)
+ if err != nil {
+ return nil, fmt.Errorf("resolving index sql graphs: %w", err)
+ }
+
+ functionsInNewSchemaByName := buildSchemaObjMap(diff.new.Functions)
+
+ functionSQLVertexGenerator := functionSQLVertexGenerator{
+ functionsInNewSchemaByName: functionsInNewSchemaByName,
+ }
+ functionGraphs, err := diff.functionDiffs.resolveToSQLGraph(&functionSQLVertexGenerator)
+ if err != nil {
+ return nil, fmt.Errorf("resolving function sql graphs: %w", err)
+ }
+
+ triggerSQLVertexGenerator := triggerSQLVertexGenerator{
+ functionsInNewSchemaByName: functionsInNewSchemaByName,
+ }
+ triggerGraphs, err := diff.triggerDiffs.resolveToSQLGraph(&triggerSQLVertexGenerator)
+ if err != nil {
+ return nil, fmt.Errorf("resolving trigger sql graphs: %w", err)
+ }
+
+ if err := tableGraphs.union(attachPartitionGraphs); err != nil {
+ return nil, fmt.Errorf("unioning table and attach partition graphs: %w", err)
+ }
+ if err := tableGraphs.union(indexGraphs); err != nil {
+ return nil, fmt.Errorf("unioning table and index graphs: %w", err)
+ }
+ if err := tableGraphs.union(renameConflictingIndexGraphs); err != nil {
+ return nil, fmt.Errorf("unioning table and rename conflicting index graphs: %w", err)
+ }
+ if err := tableGraphs.union(functionGraphs); err != nil {
+ return nil, fmt.Errorf("unioning table and function graphs: %w", err)
+ }
+ if err := tableGraphs.union(triggerGraphs); err != nil {
+ return nil, fmt.Errorf("unioning table and trigger graphs: %w", err)
+ }
+
+ return tableGraphs.toOrderedStatements()
+}
+
+func buildSchemaObjMap[S schema.Object](s []S) map[string]S {
+ output := make(map[string]S)
+ for _, obj := range s {
+ output[obj.GetName()] = obj
+ }
+ return output
+}
+
+type tableSQLVertexGenerator struct {
+ deletedTablesByName map[string]schema.Table
+ tablesInNewSchemaByName map[string]schema.Table
+}
+
+var _ sqlVertexGenerator[schema.Table, tableDiff] = &tableSQLVertexGenerator{}
+
+func (t *tableSQLVertexGenerator) Add(table schema.Table) ([]Statement, error) {
+ if table.IsPartition() {
+ if table.IsPartitioned() {
+ return nil, fmt.Errorf("partitioned partitions: %w", ErrNotImplemented)
+ }
+ if len(table.CheckConstraints) > 0 {
+ return nil, fmt.Errorf("check constraints on partitions: %w", ErrNotImplemented)
+ }
+ // We attach partitions separately, so each partition must carry the same check constraints
+ // as its parent table
+ table.CheckConstraints = append(table.CheckConstraints, t.tablesInNewSchemaByName[table.ParentTableName].CheckConstraints...)
+ }
+
+ var stmts []Statement
+
+ var columnDefs []string
+ for _, column := range table.Columns {
+ columnDefs = append(columnDefs, "\t"+buildColumnDefinition(column))
+ }
+ createTableSb := strings.Builder{}
+ createTableSb.WriteString(fmt.Sprintf("CREATE TABLE %s (\n%s\n)",
+ schema.EscapeIdentifier(table.Name),
+ strings.Join(columnDefs, ",\n"),
+ ))
+ if table.IsPartitioned() {
+ createTableSb.WriteString(fmt.Sprintf("PARTITION BY %s", table.PartitionKeyDef))
+ }
+ stmts = append(stmts, Statement{
+ DDL: createTableSb.String(),
+ Timeout: statementTimeoutDefault,
+ })
+
+ csg := checkConstraintSQLGenerator{tableName: table.Name}
+ for _, checkCon := range table.CheckConstraints {
+ addConStmts, err := csg.Add(checkCon)
+ if err != nil {
+ return nil, fmt.Errorf("generating add check constraint statements for check constraint %s: %w", checkCon.Name, err)
+ }
+ // Remove hazards from statements since the table is brand new
+ stmts = append(stmts, stripMigrationHazards(addConStmts)...)
+ }
+
+ return stmts, nil
+}
+
+func (t *tableSQLVertexGenerator) Delete(table schema.Table) ([]Statement, error) {
+ if table.IsPartition() {
+ // Don't support dropping partitions without dropping the base table. This would be easy to implement, but we
+ // would need to add tests for it.
+ //
+ // The base table might be re-created, so check if it's deleted rather than just checking whether it
+ // exists in the new schema
+ if _, baseTableDropped := t.deletedTablesByName[table.ParentTableName]; !baseTableDropped {
+ return nil, fmt.Errorf("deleting partitions without dropping parent table: %w", ErrNotImplemented)
+ }
+ // It will be dropped when the parent table is dropped
+ return nil, nil
+ }
+ return []Statement{
+ {
+ DDL: fmt.Sprintf("DROP TABLE %s", schema.EscapeIdentifier(table.Name)),
+ Timeout: statementTimeoutTableDrop,
+ Hazards: []MigrationHazard{{
+ Type: MigrationHazardTypeDeletesData,
+ Message: "Deletes all rows in the table (and the table itself)",
+ }},
+ },
+ }, nil
+}
+
+func (t *tableSQLVertexGenerator) Alter(diff tableDiff) ([]Statement, error) {
+ if diff.old.IsPartition() != diff.new.IsPartition() {
+ return nil, fmt.Errorf("changing a partition to no longer be a partition (or vice versa): %w", ErrNotImplemented)
+ } else if diff.new.IsPartition() {
+ return t.alterPartition(diff)
+ }
+
+ if diff.old.PartitionKeyDef != diff.new.PartitionKeyDef {
+ return nil, fmt.Errorf("changing partition key def: %w", ErrNotImplemented)
+ }
+
+ columnSQLGenerator := columnSQLGenerator{tableName: diff.new.Name}
+ columnGeneratedSQL, err := diff.columnsDiff.resolveToSQLGroupedByEffect(&columnSQLGenerator)
+ if err != nil {
+ return nil, fmt.Errorf("resolving index diff: %w", err)
+ }
+
+ checkConSQLGenerator := checkConstraintSQLGenerator{tableName: diff.new.Name}
+ checkConGeneratedSQL, err := diff.checkConstraintDiff.resolveToSQLGroupedByEffect(&checkConSQLGenerator)
+ if err != nil {
+ return nil, fmt.Errorf("Resolving check constraints diff: %w", err)
+ }
+
+ var stmts []Statement
+ stmts = append(stmts, checkConGeneratedSQL.Deletes...)
+ stmts = append(stmts, columnGeneratedSQL.Deletes...)
+ stmts = append(stmts, columnGeneratedSQL.Adds...)
+ stmts = append(stmts, checkConGeneratedSQL.Adds...)
+ stmts = append(stmts, columnGeneratedSQL.Alters...)
+ stmts = append(stmts, checkConGeneratedSQL.Alters...)
+ return stmts, nil
+}
+
+func (t *tableSQLVertexGenerator) alterPartition(diff tableDiff) ([]Statement, error) {
+ if diff.old.ForValues != diff.new.ForValues {
+ return nil, fmt.Errorf("altering partition FOR VALUES: %w", ErrNotImplemented)
+ }
+ if !diff.checkConstraintDiff.isEmpty() {
+ return nil, fmt.Errorf("check constraints on partitions: %w", ErrNotImplemented)
+ }
+
+ var stmts []Statement
+ // The columns diff should only contain nullability changes: columns are added and dropped on the
+ // parent (partitioned) table, not on individual partitions
+ for _, colDiff := range diff.columnsDiff.alters {
+ if colDiff.old.IsNullable == colDiff.new.IsNullable {
+ continue
+ }
+ alterColumnPrefix := fmt.Sprintf("%s ALTER COLUMN %s", alterTablePrefix(diff.new.Name), schema.EscapeIdentifier(colDiff.new.Name))
+ if colDiff.new.IsNullable {
+ stmts = append(stmts, Statement{
+ DDL: fmt.Sprintf("%s DROP NOT NULL", alterColumnPrefix),
+ Timeout: statementTimeoutDefault,
+ })
+ } else {
+ stmts = append(stmts, Statement{
+ DDL: fmt.Sprintf("%s SET NOT NULL", alterColumnPrefix),
+ Timeout: statementTimeoutDefault,
+ Hazards: []MigrationHazard{
+ {
+ Type: MigrationHazardTypeAcquiresAccessExclusiveLock,
+ Message: "Marking a column as not null requires a full table scan, which will lock out " +
+ "writes on the partition",
+ },
+ },
+ })
+ }
+ }
+
+ return stmts, nil
+}
+
+func (t *tableSQLVertexGenerator) GetSQLVertexId(table schema.Table) string {
+ return buildTableVertexId(table.Name)
+}
+
+func (t *tableSQLVertexGenerator) GetAddAlterDependencies(table, _ schema.Table) []dependency {
+ deps := []dependency{
+ mustRun(t.GetSQLVertexId(table), diffTypeAddAlter).after(t.GetSQLVertexId(table), diffTypeDelete),
+ }
+
+ if table.IsPartition() {
+ deps = append(deps,
+ mustRun(t.GetSQLVertexId(table), diffTypeAddAlter).after(buildTableVertexId(table.ParentTableName), diffTypeAddAlter),
+ )
+ }
+ return deps
+}
+
+func (t *tableSQLVertexGenerator) GetDeleteDependencies(table schema.Table) []dependency {
+ var deps []dependency
+ if table.IsPartition() {
+ deps = append(deps,
+ mustRun(t.GetSQLVertexId(table), diffTypeDelete).after(buildTableVertexId(table.ParentTableName), diffTypeDelete),
+ )
+ }
+ return deps
+}
+
+type columnSQLGenerator struct {
+ tableName string
+}
+
+func (csg *columnSQLGenerator) Add(column schema.Column) ([]Statement, error) {
+ return []Statement{{
+ DDL: fmt.Sprintf("%s ADD COLUMN %s", alterTablePrefix(csg.tableName), buildColumnDefinition(column)),
+ Timeout: statementTimeoutDefault,
+ }}, nil
+}
+
+func (csg *columnSQLGenerator) Delete(column schema.Column) ([]Statement, error) {
+ return []Statement{{
+ DDL: fmt.Sprintf("%s DROP COLUMN %s", alterTablePrefix(csg.tableName), schema.EscapeIdentifier(column.Name)),
+ Timeout: statementTimeoutDefault,
+ Hazards: []MigrationHazard{
+ {
+ Type: MigrationHazardTypeDeletesData,
+ Message: "Deletes all values in the column",
+ },
+ },
+ }}, nil
+}
+
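+// Alter generates the statements to alter a column in place: nullability changes, dropping/setting the
+// default, and type/collation changes (followed by an ANALYZE on the column to re-generate statistics).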
+func (csg *columnSQLGenerator) Alter(diff columnDiff) ([]Statement, error) {
+ if diff.oldOrdering != diff.newOrdering {
+ return nil, fmt.Errorf("old=%d; new=%d: %w", diff.oldOrdering, diff.newOrdering, ErrColumnOrderingChanged)
+ }
+ oldColumn, newColumn := diff.old, diff.new
+ var stmts []Statement
+ alterColumnPrefix := fmt.Sprintf("%s ALTER COLUMN %s", alterTablePrefix(csg.tableName), schema.EscapeIdentifier(newColumn.Name))
+
+ if oldColumn.IsNullable != newColumn.IsNullable {
+ if newColumn.IsNullable {
+ stmts = append(stmts, Statement{
+ DDL: fmt.Sprintf("%s DROP NOT NULL", alterColumnPrefix),
+ Timeout: statementTimeoutDefault,
+ })
+ } else {
+ stmts = append(stmts, Statement{
+ DDL: fmt.Sprintf("%s SET NOT NULL", alterColumnPrefix),
+ Timeout: statementTimeoutDefault,
+ Hazards: []MigrationHazard{
+ {
+ Type: MigrationHazardTypeAcquiresAccessExclusiveLock,
+ Message: "Marking a column as not null requires a full table scan, which will lock out writes",
+ },
+ },
+ })
+ }
+ }
+
+ if len(oldColumn.Default) > 0 && len(newColumn.Default) == 0 {
+ // Drop the default before type conversion. This will allow type conversions
+ // between incompatible types if the previous column has a default and the new column is dropping its default
+ stmts = append(stmts, Statement{
+ DDL: fmt.Sprintf("%s DROP DEFAULT", alterColumnPrefix),
+ Timeout: statementTimeoutDefault,
+ })
+ }
+
+ if !strings.EqualFold(oldColumn.Type, newColumn.Type) ||
+ !strings.EqualFold(oldColumn.Collation.GetFQEscapedName(), newColumn.Collation.GetFQEscapedName()) {
+ stmts = append(stmts,
+ []Statement{
+ csg.generateTypeTransformationStatement(
+ alterColumnPrefix,
+ schema.EscapeIdentifier(newColumn.Name),
+ oldColumn.Type,
+ newColumn.Type,
+ newColumn.Collation,
+ ),
+ // When "SET TYPE" is used to alter a column, that column's statistics are removed, which could
+ // affect query plans. In order to mitigate the effect on queries, re-generate the statistics for the
+ // column before continuing with the migration.
+ {
+ DDL: fmt.Sprintf("ANALYZE %s (%s)", schema.EscapeIdentifier(csg.tableName), schema.EscapeIdentifier(newColumn.Name)),
+ Timeout: statementTimeoutAnalyzeColumn,
+ Hazards: []MigrationHazard{
+ {
+ Type: MigrationHazardTypeImpactsDatabasePerformance,
+ Message: "Running analyze will read rows from the table, putting increased load " +
+ "on the database and consuming database resources. It won't prevent reads/writes to " +
+ "the table, but it could affect performance when executing queries.",
+ },
+ },
+ },
+ }...)
+ }
+
+ if oldColumn.Default != newColumn.Default && len(newColumn.Default) > 0 {
+ // Set the default after the type conversion. This will allow type conversions
+ // between incompatible types if the previous column has no default and the new column has a default
+ stmts = append(stmts, Statement{
+ DDL: fmt.Sprintf("%s SET DEFAULT %s", alterColumnPrefix, newColumn.Default),
+ Timeout: statementTimeoutDefault,
+ })
+ }
+
+ return stmts, nil
+}
+
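+// generateTypeTransformationStatement builds the "ALTER COLUMN ... SET DATA TYPE" statement. Most
+// conversions use a plain cast ("using col::newtype"); the bigint -> timestamp case instead converts via
+// to_timestamp, treating the stored bigint as a millisecond epoch.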
+func (csg *columnSQLGenerator) generateTypeTransformationStatement(
+ prefix string,
+ name string,
+ oldType string,
+ newType string,
+ newTypeCollation schema.SchemaQualifiedName,
+) Statement {
+ if strings.EqualFold(oldType, "bigint") &&
+ strings.EqualFold(newType, "timestamp without time zone") {
+ return Statement{
+ DDL: fmt.Sprintf("%s SET DATA TYPE %s using to_timestamp(%s / 1000)",
+ prefix,
+ newType,
+ name,
+ ),
+ Timeout: statementTimeoutDefault,
+ Hazards: []MigrationHazard{{
+ Type: MigrationHazardTypeAcquiresAccessExclusiveLock,
+ Message: "This will completely lock the table while the data is being " +
+ "re-written for a duration of time that scales with the size of your data. " +
+ "The values previously stored as BIGINT will be translated into a " +
+ "TIMESTAMP value via the PostgreSQL to_timestamp() function. This " +
+ "translation will assume that the values stored in BIGINT represent a " +
+ "millisecond epoch value.",
+ }},
+ }
+ }
+
+ collationModifier := ""
+ if !newTypeCollation.IsEmpty() {
+ collationModifier = fmt.Sprintf("COLLATE %s ", newTypeCollation.GetFQEscapedName())
+ }
+
+ return Statement{
+ DDL: fmt.Sprintf("%s SET DATA TYPE %s %susing %s::%s",
+ prefix,
+ newType,
+ collationModifier,
+ name,
+ newType,
+ ),
+ Timeout: statementTimeoutDefault,
+ Hazards: []MigrationHazard{{
+ Type: MigrationHazardTypeAcquiresAccessExclusiveLock,
+ Message: "This will completely lock the table while the data is being re-written. " +
+ "The duration of this conversion depends on if the type conversion is trivial " +
+ "or not. A non-trivial conversion will require a table rewrite. A trivial " +
+ "conversion is one where the binary values are coercible and the column " +
+ "contents are not changing.",
+ }},
+ }
+}
+
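+// renameConflictingIndexSQLVertexGenerator renames old indexes that are about to be re-created so the
+// replacement index (which keeps the original name) can be built before the old one is dropped. The
+// renames are recorded and exposed via getRenames so later drop statements reference the updated names.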
+type renameConflictingIndexSQLVertexGenerator struct {
+ // oldSchemaIndexesByName is a map of index name to the index in the old schema.
+ // It is used to identify if an index has been re-created
+ oldSchemaIndexesByName map[string]schema.Index
+
+ indexRenamesByOldName map[string]string
+}
+
+func newRenameConflictingIndexSQLVertexGenerator(oldSchemaIndexesByName map[string]schema.Index) renameConflictingIndexSQLVertexGenerator {
+ return renameConflictingIndexSQLVertexGenerator{
+ oldSchemaIndexesByName: oldSchemaIndexesByName,
+ indexRenamesByOldName: make(map[string]string),
+ }
+}
+
+func (rsg *renameConflictingIndexSQLVertexGenerator) Add(index schema.Index) ([]Statement, error) {
+ if oldIndex, indexIsBeingRecreated := rsg.oldSchemaIndexesByName[index.Name]; !indexIsBeingRecreated {
+ return nil, nil
+ } else if oldIndex.IsPk && index.IsPk {
+ // Don't bother renaming if both are primary keys: a table can't have two primary keys at once, so
+ // the new index must be created after the old one is dropped anyway, and renaming buys us nothing.
+ //
+ // To make changing primary keys index-gap free (mostly online), we could build the underlying new primary key index,
+ // drop the old primary constraint (and index), and then add the primary key constraint using the new index.
+ // This would require us to split primary key constraint SQL generation from index SQL generation
+ return nil, nil
+ }
+
+ newName, err := rsg.generateNonConflictingName(index)
+ if err != nil {
+ return nil, fmt.Errorf("generating non-conflicting name: %w", err)
+ }
+
+ rsg.indexRenamesByOldName[index.Name] = newName
+
+ return []Statement{{
+ DDL: fmt.Sprintf("ALTER INDEX %s RENAME TO %s", schema.EscapeIdentifier(index.Name), schema.EscapeIdentifier(newName)),
+ Timeout: statementTimeoutDefault,
+ }}, nil
+}
+
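+// generateNonConflictingName appends a random UUID suffix to the index name, truncating the original name
+// if needed so the result stays within maxPostgresIdentifierSize. For example, "foobar_idx" might become
+// something like "foobar_idx_1b4e28ba-..." (the exact suffix is random).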
+func (rsg *renameConflictingIndexSQLVertexGenerator) generateNonConflictingName(index schema.Index) (string, error) {
+ uuid, err := uuid.NewRandom()
+ if err != nil {
+ return "", fmt.Errorf("generating UUID: %w", err)
+ }
+
+ newNameSuffix := fmt.Sprintf("_%s", uuid.String())
+ idxNameTruncationIdx := len(index.Name)
+ if len(index.Name) > maxPostgresIdentifierSize-len(newNameSuffix) {
+ idxNameTruncationIdx = maxPostgresIdentifierSize - len(newNameSuffix)
+ }
+
+ return index.Name[:idxNameTruncationIdx] + newNameSuffix, nil
+}
+
+func (rsg *renameConflictingIndexSQLVertexGenerator) getRenames() map[string]string {
+ return rsg.indexRenamesByOldName
+}
+
+func (rsg *renameConflictingIndexSQLVertexGenerator) Delete(_ schema.Index) ([]Statement, error) {
+ return nil, nil
+}
+
+func (rsg *renameConflictingIndexSQLVertexGenerator) Alter(_ indexDiff) ([]Statement, error) {
+ return nil, nil
+}
+
+func (*renameConflictingIndexSQLVertexGenerator) GetSQLVertexId(index schema.Index) string {
+ return buildRenameConflictingIndexVertexId(index.Name)
+}
+
+func (rsg *renameConflictingIndexSQLVertexGenerator) GetAddAlterDependencies(_, _ schema.Index) []dependency {
+ return nil
+}
+
+func (rsg *renameConflictingIndexSQLVertexGenerator) GetDeleteDependencies(_ schema.Index) []dependency {
+ return nil
+}
+
+func buildRenameConflictingIndexVertexId(indexName string) string {
+ return buildVertexId("indexrename", indexName)
+}
+
+type indexSQLVertexGenerator struct {
+ // deletedTablesByName is a map of table name to the deleted tables (and partitions)
+ deletedTablesByName map[string]schema.Table
+ // addedTablesByName is a map of table name to the new tables (and partitions)
+ // This is used to identify if hazards are necessary
+ addedTablesByName map[string]schema.Table
+ // tablesInNewSchemaByName is a map of table name to tables (and partitions) in the new schema.
+ // These tables are not necessarily new. This is used to identify if the table is partitioned
+ tablesInNewSchemaByName map[string]schema.Table
+ // indexesInNewSchemaByName is a map of index name to the index
+ // This is used to identify if the parent index is a primary key
+ indexesInNewSchemaByName map[string]schema.Index
+ // indexRenamesByOldName is a map of any renames performed by the conflicting index sql vertex generator
+ indexRenamesByOldName map[string]string
+}
+
+func (isg *indexSQLVertexGenerator) Add(index schema.Index) ([]Statement, error) {
+ stmts, err := isg.addIdxStmtsWithHazards(index)
+ if err != nil {
+ return stmts, err
+ }
+
+ if _, isNewTable := isg.addedTablesByName[index.TableName]; isNewTable {
+ stmts = stripMigrationHazards(stmts)
+ }
+ return stmts, nil
+}
+
+func (isg *indexSQLVertexGenerator) addIdxStmtsWithHazards(index schema.Index) ([]Statement, error) {
+ if index.IsInvalid {
+ return nil, fmt.Errorf("can't create an invalid index: %w", ErrNotImplemented)
+ }
+
+ if index.IsPk {
+ if index.IsPartitionOfIndex() {
+ if parentIdx, ok := isg.indexesInNewSchemaByName[index.ParentIdxName]; !ok {
+ return nil, fmt.Errorf("could not find parent index %s", index.ParentIdxName)
+ } else if parentIdx.IsPk {
+ // All indexes associated with parent primary keys are automatically created by their parent
+ return nil, nil
+ }
+ }
+
+ if isOnPartitionedTable, err := isg.isOnPartitionedTable(index); err != nil {
+ return nil, err
+ } else if isOnPartitionedTable {
+ // A partitioned table can't have a constraint added to it with "USING INDEX", so the index is
+ // created implicitly by adding the constraint. This currently blocks writes and is a dangerous operation.
+ // If users want to be able to switch primary keys concurrently, support can be added in the future
+ // with a strategy similar to adding indexes to a partitioned table
+ return []Statement{
+ {
+ DDL: fmt.Sprintf("%s ADD CONSTRAINT %s PRIMARY KEY (%s)",
+ alterTablePrefix(index.TableName),
+ schema.EscapeIdentifier(index.Name),
+ strings.Join(formattedNamesForSQL(index.Columns), ", "),
+ ),
+ Timeout: statementTimeoutDefault,
+ Hazards: []MigrationHazard{
+ {
+ Type: MigrationHazardTypeAcquiresShareLock,
+ Message: "This will lock writes to the table while the index build occurs.",
+ },
+ {
+ Type: MigrationHazardTypeIndexBuild,
+ Message: "This is non-concurrent because adding PK's concurrently hasn't been" +
+ "implemented yet. It WILL lock out writes. Index builds require a non-trivial " +
+ "amount of CPU as well, which might affect database performance",
+ },
+ },
+ },
+ }, nil
+ }
+ }
+
+ var stmts []Statement
+ var createIdxStmtHazards []MigrationHazard
+
+ createIdxStmt := string(index.GetIndexDefStmt)
+ createIdxStmtTimeout := statementTimeoutDefault
+ if isOnPartitionedTable, err := isg.isOnPartitionedTable(index); err != nil {
+ return nil, err
+ } else if !isOnPartitionedTable {
+ // Only indexes on non-partitioned tables can be created concurrently
+ concurrentCreateIdxStmt, err := index.GetIndexDefStmt.ToCreateIndexConcurrently()
+ if err != nil {
+ return nil, fmt.Errorf("modifying index def statement to concurrently: %w", err)
+ }
+ createIdxStmt = concurrentCreateIdxStmt
+ createIdxStmtHazards = append(createIdxStmtHazards, MigrationHazard{
+ Type: MigrationHazardTypeIndexBuild,
+ Message: "This might affect database performance. " +
+ "Concurrent index builds require a non-trivial amount of CPU, potentially affecting database performance. " +
+ "They also can take a while but do not lock out writes.",
+ })
+ createIdxStmtTimeout = statementTimeoutConcurrentIndexBuild
+ }
+
+ stmts = append(stmts, Statement{
+ DDL: createIdxStmt,
+ Timeout: createIdxStmtTimeout,
+ Hazards: createIdxStmtHazards,
+ })
+
+ _, isNewTable := isg.addedTablesByName[index.TableName]
+ if index.IsPartitionOfIndex() && !isNewTable {
+ // Exclude if the partition is new because the index will be attached when the partition is attached
+ stmts = append(stmts, buildAttachIndex(index))
+ }
+
+ if index.IsPk {
+ stmts = append(stmts, isg.addPkConstraintUsingIdx(index))
+ } else if len(index.ConstraintName) > 0 {
+ return nil, fmt.Errorf("constraints not supported for non-primary key indexes: %w", ErrNotImplemented)
+ }
+ return stmts, nil
+}
+
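+// Delete generates the statements to drop an index. If the index backs a constraint, the constraint is
+// dropped instead (which drops the index with it); otherwise the index itself is dropped, concurrently
+// where Postgres supports it (i.e., not on partitioned tables).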
+func (isg *indexSQLVertexGenerator) Delete(index schema.Index) ([]Statement, error) {
+ _, tableWasDeleted := isg.deletedTablesByName[index.TableName]
+ // An index will be dropped if its owning table is dropped.
+ // Similarly, a partition of an index will be dropped when the parent index is dropped
+ if tableWasDeleted || index.IsPartitionOfIndex() {
+ return nil, nil
+ }
+
+ // An index used by a primary key constraint/unique constraint cannot be dropped concurrently
+ if len(index.ConstraintName) > 0 {
+ // The index may have been renamed, which also renames the constraint. Use the updated name
+ constraintName := index.ConstraintName
+ if rename, hasRename := isg.indexRenamesByOldName[index.Name]; hasRename {
+ constraintName = rename
+ }
+
+ // Dropping the constraint will automatically drop the index. There is no way to drop
+ // the constraint without dropping the index
+ return []Statement{
+ {
+ DDL: dropConstraintDDL(index.TableName, constraintName),
+ Timeout: statementTimeoutDefault,
+ Hazards: []MigrationHazard{
+ migrationHazardIndexDroppedAcquiresLock,
+ migrationHazardIndexDroppedQueryPerf,
+ },
+ },
+ }, nil
+ }
+
+ var dropIndexStmtHazards []MigrationHazard
+ concurrentlyModifier := "CONCURRENTLY "
+ dropIndexStmtTimeout := statementTimeoutConcurrentIndexDrop
+ if isOnPartitionedTable, err := isg.isOnPartitionedTable(index); err != nil {
+ return nil, err
+ } else if isOnPartitionedTable {
+ // Currently, Postgres has no good way of dropping an index partition concurrently
+ concurrentlyModifier = ""
+ dropIndexStmtTimeout = statementTimeoutDefault
+ // Technically, CONCURRENTLY also locks the table, but it waits for an "opportunity" to lock
+ // We will omit the locking hazard of concurrent drops for now
+ dropIndexStmtHazards = append(dropIndexStmtHazards, migrationHazardIndexDroppedAcquiresLock)
+ }
+
+ // The index may have been renamed. Use the updated name
+ indexName := index.Name
+ if rename, hasRename := isg.indexRenamesByOldName[index.Name]; hasRename {
+ indexName = rename
+ }
+
+ return []Statement{{
+ DDL: fmt.Sprintf("DROP INDEX %s%s", concurrentlyModifier, schema.EscapeIdentifier(indexName)),
+ Timeout: dropIndexStmtTimeout,
+ Hazards: append(dropIndexStmtHazards, migrationHazardIndexDroppedQueryPerf),
+ }}, nil
+}
+
+func (isg *indexSQLVertexGenerator) Alter(diff indexDiff) ([]Statement, error) {
+ var stmts []Statement
+
+ if isOnPartitionedTable, err := isg.isOnPartitionedTable(diff.new); err != nil {
+ return nil, err
+ } else if isOnPartitionedTable && diff.old.IsInvalid && !diff.new.IsInvalid {
+ // If the index is a partitioned index, it can be made valid automatically by attaching the index partitions
+ diff.old.IsInvalid = diff.new.IsInvalid
+ }
+
+ if !diff.new.IsPartitionOfIndex() && !diff.old.IsPk && diff.new.IsPk {
+ stmts = append(stmts, isg.addPkConstraintUsingIdx(diff.new))
+ diff.old.IsPk = diff.new.IsPk
+ diff.old.ConstraintName = diff.new.ConstraintName
+ }
+
+ if len(diff.old.ParentIdxName) == 0 && len(diff.new.ParentIdxName) > 0 {
+ stmts = append(stmts, buildAttachIndex(diff.new))
+ diff.old.ParentIdxName = diff.new.ParentIdxName
+ }
+
+ if !cmp.Equal(diff.old, diff.new) {
+ return nil, fmt.Errorf("index diff could not be resolved %s", cmp.Diff(diff.old, diff.new))
+ }
+
+ return stmts, nil
+}
+
+func (isg *indexSQLVertexGenerator) isOnPartitionedTable(index schema.Index) (bool, error) {
+ return isOnPartitionedTable(isg.tablesInNewSchemaByName, index)
+}
+
+ // Returns true if the table the index belongs to is partitioned. If the table is a partition of a
+ // partitioned table, this will always return false
+func isOnPartitionedTable(tablesInNewSchemaByName map[string]schema.Table, index schema.Index) (bool, error) {
+ if owningTable, ok := tablesInNewSchemaByName[index.TableName]; !ok {
+ return false, fmt.Errorf("could not find table in new schema with name %s", index.TableName)
+ } else {
+ return owningTable.IsPartitioned(), nil
+ }
+}
+
+func (isg *indexSQLVertexGenerator) addPkConstraintUsingIdx(index schema.Index) Statement {
+ return Statement{
+ DDL: fmt.Sprintf("%s ADD CONSTRAINT %s PRIMARY KEY USING INDEX %s", alterTablePrefix(index.TableName), schema.EscapeIdentifier(index.ConstraintName), schema.EscapeIdentifier(index.Name)),
+ Timeout: statementTimeoutDefault,
+ }
+}
+
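+// buildAttachIndex attaches an index partition to its parent index, e.g. (assuming identifiers are
+// double-quoted by EscapeIdentifier): ALTER INDEX "parent_idx" ATTACH PARTITION "child_idx"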
+func buildAttachIndex(index schema.Index) Statement {
+ return Statement{
+ DDL: fmt.Sprintf("ALTER INDEX %s ATTACH PARTITION %s", schema.EscapeIdentifier(index.ParentIdxName), schema.EscapeIdentifier(index.Name)),
+ Timeout: statementTimeoutDefault,
+ }
+}
+
+func (*indexSQLVertexGenerator) GetSQLVertexId(index schema.Index) string {
+ return buildIndexVertexId(index.Name)
+}
+
+func (isg *indexSQLVertexGenerator) GetAddAlterDependencies(index, _ schema.Index) []dependency {
+ dependencies := []dependency{
+ mustRun(isg.GetSQLVertexId(index), diffTypeAddAlter).after(buildTableVertexId(index.TableName), diffTypeAddAlter),
+ // To allow for online changes to indexes, rename the older version of the index (if it exists) before the new version is added
+ mustRun(isg.GetSQLVertexId(index), diffTypeAddAlter).after(buildRenameConflictingIndexVertexId(index.Name), diffTypeAddAlter),
+ }
+
+ if index.IsPartitionOfIndex() {
+ // Partitions of indexes must be created after the parent index is created
+ dependencies = append(dependencies,
+ mustRun(isg.GetSQLVertexId(index), diffTypeAddAlter).after(buildIndexVertexId(index.ParentIdxName), diffTypeAddAlter))
+ }
+
+ return dependencies
+}
+
+func (isg *indexSQLVertexGenerator) GetDeleteDependencies(index schema.Index) []dependency {
+ dependencies := []dependency{
+ mustRun(isg.GetSQLVertexId(index), diffTypeDelete).after(buildTableVertexId(index.TableName), diffTypeDelete),
+ // Drop the index after it has been potentially renamed
+ mustRun(isg.GetSQLVertexId(index), diffTypeDelete).after(buildRenameConflictingIndexVertexId(index.Name), diffTypeAddAlter),
+ }
+
+ if index.IsPartitionOfIndex() {
+ // Since dropping the parent index will cause the partition of the index to drop, the parent drop should come
+ // before
+ dependencies = append(dependencies,
+ mustRun(isg.GetSQLVertexId(index), diffTypeDelete).after(buildIndexVertexId(index.ParentIdxName), diffTypeDelete))
+ }
+ dependencies = append(dependencies, isg.addDepsOnTableAddAlterIfNecessary(index)...)
+
+ return dependencies
+}
+
+func (isg *indexSQLVertexGenerator) addDepsOnTableAddAlterIfNecessary(index schema.Index) []dependency {
+ // This could be cleaner if we start sorting columns separately in the graph
+
+ parentTable, ok := isg.tablesInNewSchemaByName[index.TableName]
+ if !ok {
+ // If the parent table is deleted, we don't need to worry about making the index statement come
+ // before any alters
+ return nil
+ }
+
+ // These dependencies will force the index deletion statement to come before the table AddAlter
+ addAlterColumnDeps := []dependency{
+ mustRun(isg.GetSQLVertexId(index), diffTypeDelete).before(buildTableVertexId(index.TableName), diffTypeAddAlter),
+ }
+ if len(parentTable.ParentTableName) > 0 {
+ // If the table is partitioned, column modifications occur on the base table, not the children. Thus, we
+ // need the dependency to also be on the parent table add/alter statements
+ addAlterColumnDeps = append(
+ addAlterColumnDeps,
+ mustRun(isg.GetSQLVertexId(index), diffTypeDelete).before(buildTableVertexId(parentTable.ParentTableName), diffTypeAddAlter),
+ )
+ }
+
+ // If the parent table still exists and the index is a primary key, we should drop the PK index before
+ // any statements associated with altering the table run. This is important for changing the nullability of
+ // columns
+ if index.IsPk {
+ return addAlterColumnDeps
+ }
+
+ parentTableColumnsByName := buildSchemaObjMap(parentTable.Columns)
+ for _, idxColumn := range index.Columns {
+ // We need to force the index drop to come before the statements to drop columns. Otherwise, the column
+ // drops will force the index to drop non-concurrently
+ if _, columnStillPresent := parentTableColumnsByName[idxColumn]; !columnStillPresent {
+ return addAlterColumnDeps
+ }
+ }
+
+ return nil
+}
+
+type checkConstraintSQLGenerator struct {
+ tableName string
+}
+
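+// Add builds the ADD CONSTRAINT ... CHECK statement, appending NO INHERIT and/or NOT VALID when
+// applicable, e.g. (assuming identifiers are double-quoted by EscapeIdentifier):
+// ALTER TABLE "foo" ADD CONSTRAINT "foo_bar_check" CHECK(bar > 0) NOT VALID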
+func (csg *checkConstraintSQLGenerator) Add(con schema.CheckConstraint) ([]Statement, error) {
+ // UDFs in check constraints are a bad idea: check constraints are not re-validated
+ // if the UDF changes, so it's not really a safe practice. We won't support it for now
+ if len(con.DependsOnFunctions) > 0 {
+ return nil, fmt.Errorf("check constraints that depend on UDFs: %w", ErrNotImplemented)
+ }
+
+ sb := strings.Builder{}
+ sb.WriteString(fmt.Sprintf("%s ADD CONSTRAINT %s CHECK(%s)", alterTablePrefix(csg.tableName), schema.EscapeIdentifier(con.Name), con.Expression))
+ if !con.IsInheritable {
+ sb.WriteString(" NO INHERIT")
+ }
+
+ if !con.IsValid {
+ sb.WriteString(" NOT VALID")
+ }
+
+ return []Statement{{
+ DDL: sb.String(),
+ Timeout: statementTimeoutDefault,
+ }}, nil
+}
+
+func (csg *checkConstraintSQLGenerator) Delete(con schema.CheckConstraint) ([]Statement, error) {
+ // We won't support deleting check constraints that depend on UDFs, to align with not supporting adding check
+ // constraints that depend on UDFs
+ if len(con.DependsOnFunctions) > 0 {
+ return nil, fmt.Errorf("check constraints that depend on UDFs: %w", ErrNotImplemented)
+ }
+
+ return []Statement{{
+ DDL: dropConstraintDDL(csg.tableName, con.Name),
+ Timeout: statementTimeoutDefault,
+ }}, nil
+}
+
+func (csg *checkConstraintSQLGenerator) Alter(diff checkConstraintDiff) ([]Statement, error) {
+ if cmp.Equal(diff.old, diff.new) {
+ return nil, nil
+ }
+
+ oldCopy := diff.old
+ oldCopy.IsValid = diff.new.IsValid
+ if !cmp.Equal(oldCopy, diff.new) {
+ // Technically, we could support altering the expression, but I don't see the use case for it. It would require more test
+ // cases than force re-adding the constraint, and I'm not convinced it unlocks any functionality
+ return nil, fmt.Errorf("altering check constraint to resolve the following diff %s: %w", cmp.Diff(oldCopy, diff.new), ErrNotImplemented)
+ } else if diff.old.IsValid && !diff.new.IsValid {
+ return nil, fmt.Errorf("check constraint can't go from invalid to valid")
+ } else if len(diff.old.DependsOnFunctions) > 0 || len(diff.new.DependsOnFunctions) > 0 {
+ return nil, fmt.Errorf("check constraints that depend on UDFs: %w", ErrNotImplemented)
+ }
+
+ return []Statement{{
+ DDL: fmt.Sprintf("%s VALIDATE CONSTRAINT %s", alterTablePrefix(csg.tableName), schema.EscapeIdentifier(diff.old.Name)),
+ Timeout: statementTimeoutDefault,
+ }}, nil
+}
+
+type attachPartitionSQLVertexGenerator struct {
+ indexesInNewSchemaByTableName map[string][]schema.Index
+}
+
+func (*attachPartitionSQLVertexGenerator) Add(table schema.Table) ([]Statement, error) {
+ if !table.IsPartition() {
+ return nil, nil
+ }
+ return []Statement{buildAttachPartitionStatement(table)}, nil
+}
+
+func (*attachPartitionSQLVertexGenerator) Alter(_ tableDiff) ([]Statement, error) {
+ return nil, nil
+}
+
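+// buildAttachPartitionStatement attaches a partition to its parent table using the partition's FOR VALUES
+// clause, e.g. (assuming identifiers are double-quoted by EscapeIdentifier):
+// ALTER TABLE "parent" ATTACH PARTITION "child" FOR VALUES FROM (1) TO (100)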
+func buildAttachPartitionStatement(table schema.Table) Statement {
+ return Statement{
+ DDL: fmt.Sprintf("%s ATTACH PARTITION %s %s", alterTablePrefix(table.ParentTableName), schema.EscapeIdentifier(table.Name), table.ForValues),
+ Timeout: statementTimeoutDefault,
+ }
+}
+
+func (*attachPartitionSQLVertexGenerator) Delete(_ schema.Table) ([]Statement, error) {
+ return nil, nil
+}
+
+func (*attachPartitionSQLVertexGenerator) GetSQLVertexId(table schema.Table) string {
+ return fmt.Sprintf("attachpartition_%s", table.Name)
+}
+
+func (a *attachPartitionSQLVertexGenerator) GetAddAlterDependencies(table, _ schema.Table) []dependency {
+ deps := []dependency{
+ mustRun(a.GetSQLVertexId(table), diffTypeAddAlter).after(buildTableVertexId(table.Name), diffTypeAddAlter),
+ }
+
+ for _, idx := range a.indexesInNewSchemaByTableName[table.Name] {
+ deps = append(deps, mustRun(a.GetSQLVertexId(table), diffTypeAddAlter).after(buildIndexVertexId(idx.Name), diffTypeAddAlter))
+ }
+ return deps
+}
+
+func (a *attachPartitionSQLVertexGenerator) GetDeleteDependencies(_ schema.Table) []dependency {
+ return nil
+}
+
+type functionSQLVertexGenerator struct {
+ // functionsInNewSchemaByName is a map of function name to functions in the new schema.
+ // These functions are not necessarily new
+ functionsInNewSchemaByName map[string]schema.Function
+}
+
+func (f *functionSQLVertexGenerator) Add(function schema.Function) ([]Statement, error) {
+ var hazards []MigrationHazard
+ if !canFunctionDependenciesBeTracked(function) {
+ hazards = append(hazards, migrationHazardAddAlterFunctionCannotTrackDependencies)
+ }
+ return []Statement{{
+ DDL: function.FunctionDef,
+ Timeout: statementTimeoutDefault,
+ Hazards: hazards,
+ }}, nil
+}
+
+func (f *functionSQLVertexGenerator) Delete(function schema.Function) ([]Statement, error) {
+ var hazards []MigrationHazard
+ if !canFunctionDependenciesBeTracked(function) {
+ hazards = append(hazards, MigrationHazard{
+ Type: MigrationHazardTypeHasUntrackableDependencies,
+ Message: "Dependencies, i.e. other functions used in the function body, of non-sql functions cannot be " +
+ "tracked. As a result, we cannot guarantee that function dependencies are ordered properly relative to " +
+ "this statement. For drops, this means you need to ensure that all functions this function depends on " +
+ "are dropped after this statement.",
+ })
+ }
+ return []Statement{{
+ DDL: fmt.Sprintf("DROP FUNCTION %s", function.GetFQEscapedName()),
+ Timeout: statementTimeoutDefault,
+ Hazards: hazards,
+ }}, nil
+}
+
+func (f *functionSQLVertexGenerator) Alter(diff functionDiff) ([]Statement, error) {
+ if cmp.Equal(diff.old, diff.new) {
+ return nil, nil
+ }
+
+ var hazards []MigrationHazard
+ if !canFunctionDependenciesBeTracked(diff.new) {
+ hazards = append(hazards, migrationHazardAddAlterFunctionCannotTrackDependencies)
+ }
+ return []Statement{{
+ DDL: diff.new.FunctionDef,
+ Timeout: statementTimeoutDefault,
+ Hazards: hazards,
+ }}, nil
+}
+
+func canFunctionDependenciesBeTracked(function schema.Function) bool {
+ return function.Language == "sql"
+}
+
+func (f *functionSQLVertexGenerator) GetSQLVertexId(function schema.Function) string {
+ return buildFunctionVertexId(function.SchemaQualifiedName)
+}
+
+func (f *functionSQLVertexGenerator) GetAddAlterDependencies(newFunction, oldFunction schema.Function) []dependency {
+ // Since functions can just be `CREATE OR REPLACE`, there will never be a case where a function is
+ // added and dropped in the same migration. Thus, we don't need a dependency on the delete vertex of a function
+ // because there won't be one if it is being added/altered
+ var deps []dependency
+ for _, depFunction := range newFunction.DependsOnFunctions {
+ deps = append(deps, mustRun(f.GetSQLVertexId(newFunction), diffTypeAddAlter).after(buildFunctionVertexId(depFunction), diffTypeAddAlter))
+ }
+
+ if !cmp.Equal(oldFunction, schema.Function{}) {
+ // If the function is being altered:
+ // If the old version of the function calls other functions that are being deleted, those deletions
+ // must come after the function is altered, so it is no longer dependent on those dropped functions
+ for _, depFunction := range oldFunction.DependsOnFunctions {
+ deps = append(deps, mustRun(f.GetSQLVertexId(newFunction), diffTypeAddAlter).before(buildFunctionVertexId(depFunction), diffTypeDelete))
+ }
+ }
+
+ return deps
+}
+
+func (f *functionSQLVertexGenerator) GetDeleteDependencies(function schema.Function) []dependency {
+ var deps []dependency
+ for _, depFunction := range function.DependsOnFunctions {
+ deps = append(deps, mustRun(f.GetSQLVertexId(function), diffTypeDelete).before(buildFunctionVertexId(depFunction), diffTypeDelete))
+ }
+ return deps
+}
+
+func buildFunctionVertexId(name schema.SchemaQualifiedName) string {
+ return buildVertexId("function", name.GetFQEscapedName())
+}
+
+type triggerSQLVertexGenerator struct {
+ // functionsInNewSchemaByName is a map of function name to functions in the new schema.
+ // These functions are not necessarily new
+ functionsInNewSchemaByName map[string]schema.Function
+}
+
+func (t *triggerSQLVertexGenerator) Add(trigger schema.Trigger) ([]Statement, error) {
+ return []Statement{{
+ DDL: string(trigger.GetTriggerDefStmt),
+ Timeout: statementTimeoutDefault,
+ }}, nil
+}
+
+func (t *triggerSQLVertexGenerator) Delete(trigger schema.Trigger) ([]Statement, error) {
+ return []Statement{{
+ DDL: fmt.Sprintf("DROP TRIGGER %s ON %s", trigger.EscapedName, trigger.OwningTable.GetFQEscapedName()),
+ Timeout: statementTimeoutDefault,
+ }}, nil
+}
+
+func (t *triggerSQLVertexGenerator) Alter(diff triggerDiff) ([]Statement, error) {
+ if cmp.Equal(diff.old, diff.new) {
+ return nil, nil
+ }
+
+ createOrReplaceStmt, err := diff.new.GetTriggerDefStmt.ToCreateOrReplace()
+ if err != nil {
+ return nil, fmt.Errorf("modifying get trigger def statement to create or replace: %w", err)
+ }
+ return []Statement{{
+ DDL: createOrReplaceStmt,
+ Timeout: statementTimeoutDefault,
+ }}, nil
+}
+
+func (t *triggerSQLVertexGenerator) GetSQLVertexId(trigger schema.Trigger) string {
+ return buildVertexId("trigger", trigger.GetName())
+}
+
+func (t *triggerSQLVertexGenerator) GetAddAlterDependencies(newTrigger, oldTrigger schema.Trigger) []dependency {
+ // Since a trigger can just be `CREATE OR REPLACE`, there will never be a case where a trigger is
+ // added and dropped in the same migration. Thus, we don't need a dependency on the delete vertex of a trigger
+ // because there won't be one if it is being added/altered
+ deps := []dependency{
+ mustRun(t.GetSQLVertexId(newTrigger), diffTypeAddAlter).after(buildFunctionVertexId(newTrigger.Function), diffTypeAddAlter),
+ mustRun(t.GetSQLVertexId(newTrigger), diffTypeAddAlter).after(buildTableVertexId(newTrigger.OwningTableUnescapedName), diffTypeAddAlter),
+ }
+
+ if !cmp.Equal(oldTrigger, schema.Trigger{}) {
+ // If the trigger is being altered:
+ // If the old version of the trigger called a function being deleted, the function deletion must come after the
+ // trigger is altered, so the trigger no longer has a dependency on the function
+ deps = append(deps,
+ mustRun(t.GetSQLVertexId(newTrigger), diffTypeAddAlter).before(buildFunctionVertexId(oldTrigger.Function), diffTypeDelete),
+ )
+ }
+
+ return deps
+}
+
+func (t *triggerSQLVertexGenerator) GetDeleteDependencies(trigger schema.Trigger) []dependency {
+ return []dependency{
+ mustRun(t.GetSQLVertexId(trigger), diffTypeDelete).before(buildFunctionVertexId(trigger.Function), diffTypeDelete),
+ mustRun(t.GetSQLVertexId(trigger), diffTypeDelete).before(buildTableVertexId(trigger.OwningTableUnescapedName), diffTypeDelete),
+ }
+}
+
+func buildVertexId(objType string, id string) string {
+ return fmt.Sprintf("%s_%s", objType, id)
+}
+
+func stripMigrationHazards(stmts []Statement) []Statement {
+ var noHazardsStmts []Statement
+ for _, stmt := range stmts {
+ stmt.Hazards = nil
+ noHazardsStmts = append(noHazardsStmts, stmt)
+ }
+ return noHazardsStmts
+}
+
+func dropConstraintDDL(tableName, constraintName string) string {
+ return fmt.Sprintf("%s DROP CONSTRAINT %s", alterTablePrefix(tableName), schema.EscapeIdentifier(constraintName))
+}
+
+func alterTablePrefix(tableName string) string {
+ return fmt.Sprintf("ALTER TABLE %s", schema.EscapeIdentifier(tableName))
+}
+
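+// buildColumnDefinition builds the column clause used by CREATE TABLE and ADD COLUMN, e.g. (assuming
+// identifiers are double-quoted by EscapeIdentifier): "id" bigint NOT NULL DEFAULT 0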
+func buildColumnDefinition(column schema.Column) string {
+ sb := strings.Builder{}
+ sb.WriteString(fmt.Sprintf("%s %s", schema.EscapeIdentifier(column.Name), column.Type))
+ if column.IsCollated() {
+ sb.WriteString(fmt.Sprintf(" COLLATE %s", column.Collation.GetFQEscapedName()))
+ }
+ if !column.IsNullable {
+ sb.WriteString(" NOT NULL")
+ }
+ if len(column.Default) > 0 {
+ sb.WriteString(fmt.Sprintf(" DEFAULT %s", column.Default))
+ }
+ return sb.String()
+}
+
+func formattedNamesForSQL(names []string) []string {
+ var formattedNames []string
+ for _, name := range names {
+ formattedNames = append(formattedNames, schema.EscapeIdentifier(name))
+ }
+ return formattedNames
+}
diff --git a/pkg/diff/sql_graph.go b/pkg/diff/sql_graph.go
new file mode 100644
index 0000000..9088562
--- /dev/null
+++ b/pkg/diff/sql_graph.go
@@ -0,0 +1,66 @@
+package diff
+
+import (
+ "fmt"
+
+ "github.com/stripe/pg-schema-diff/internal/graph"
+)
+
+type sqlVertex struct {
+ ObjId string
+ Statements []Statement
+ DiffType diffType
+}
+
+func (s sqlVertex) GetId() string {
+ return fmt.Sprintf("%s_%s", s.DiffType, s.ObjId)
+}
+
+func buildTableVertexId(name string) string {
+ return fmt.Sprintf("table_%s", name)
+}
+
+func buildIndexVertexId(name string) string {
+ return fmt.Sprintf("index_%s", name)
+}
+
+// sqlGraph represents two dependency webs of SQL statements
+type sqlGraph graph.Graph[sqlVertex]
+
+// union unions the two AddsAndAlters graphs and, separately, unions the two delete graphs
+func (s *sqlGraph) union(sqlGraph *sqlGraph) error {
+ if err := (*graph.Graph[sqlVertex])(s).Union((*graph.Graph[sqlVertex])(sqlGraph), mergeSQLVertices); err != nil {
+ return fmt.Errorf("unioning the graphs: %w", err)
+ }
+ return nil
+}
+
+func mergeSQLVertices(old, new sqlVertex) sqlVertex {
+ return sqlVertex{
+ ObjId: old.ObjId,
+ DiffType: old.DiffType,
+ Statements: append(old.Statements, new.Statements...),
+ }
+}
+
+func (s *sqlGraph) toOrderedStatements() ([]Statement, error) {
+ vertices, err := (*graph.Graph[sqlVertex])(s).TopologicallySortWithPriority(graph.IsLowerPriorityFromGetPriority(
+ func(vertex sqlVertex) int {
+ multiplier := 1
+ if vertex.DiffType == diffTypeDelete {
+ multiplier = -1
+ }
+ // Prioritize adds/alters over deletes. Weight by number of statements. A 0 statement delete should be
+ // prioritized over a 1 statement delete
+ return len(vertex.Statements) * multiplier
+ }),
+ )
+ if err != nil {
+ return nil, fmt.Errorf("topologically sorting graph: %w", err)
+ }
+ var stmts []Statement
+ for _, v := range vertices {
+ stmts = append(stmts, v.Statements...)
+ }
+ return stmts, nil
+}
diff --git a/pkg/diff/transform_diff.go b/pkg/diff/transform_diff.go
new file mode 100644
index 0000000..98ff5ca
--- /dev/null
+++ b/pkg/diff/transform_diff.go
@@ -0,0 +1,44 @@
+package diff
+
+import (
+ "sort"
+
+ "github.com/stripe/pg-schema-diff/internal/schema"
+)
+
+// dataPackNewTables packs the columns in new tables to minimize the space they occupy
+//
+ // Note: We need to copy every slice we modify because slices share their backing arrays with the
+ // original structs. Go is copy-by-value, but a copied slice header still points at the same array,
+ // so mutating an element in place would also mutate the original schemaDiff struct
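+//
+// For example, a new table whose columns have sizes (3, 2, 10) ends up with its columns reordered to
+// sizes (10, 3, 2).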
+func dataPackNewTables(s schemaDiff) schemaDiff {
+ copiedNewTables := append([]schema.Table(nil), s.tableDiffs.adds...)
+ for i := range copiedNewTables {
+ copiedColumns := append([]schema.Column(nil), copiedNewTables[i].Columns...)
+ copiedNewTables[i].Columns = copiedColumns
+ sort.Slice(copiedColumns, func(i, j int) bool {
+ // Sort in descending order of size
+ return copiedColumns[i].Size > copiedColumns[j].Size
+ })
+ }
+ s.tableDiffs.adds = copiedNewTables
+
+ return s
+}
+
+// removeChangesToColumnOrdering removes any changes to column ordering. In effect, it tells the SQL
+// generator to ignore changes to column ordering
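+//
+// For example, a column diff with oldOrdering=0 and newOrdering=2 becomes oldOrdering=2 and newOrdering=2,
+// so the column SQL generator no longer sees an ordering change (and won't return ErrColumnOrderingChanged).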
+func removeChangesToColumnOrdering(s schemaDiff) schemaDiff {
+ copiedTableDiffs := append([]tableDiff(nil), s.tableDiffs.alters...)
+ for i := range copiedTableDiffs {
+ copiedColDiffs := append([]columnDiff(nil), copiedTableDiffs[i].columnsDiff.alters...)
+ for j := range copiedColDiffs {
+ copiedColDiffs[j].oldOrdering = copiedColDiffs[j].newOrdering
+ }
+ copiedTableDiffs[i].columnsDiff.alters = copiedColDiffs
+ }
+ s.tableDiffs.alters = copiedTableDiffs
+
+ return s
+}
diff --git a/pkg/diff/transform_diff_test.go b/pkg/diff/transform_diff_test.go
new file mode 100644
index 0000000..8c509b8
--- /dev/null
+++ b/pkg/diff/transform_diff_test.go
@@ -0,0 +1,354 @@
+package diff
+
+import (
+ "sort"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stripe/pg-schema-diff/internal/schema"
+)
+
+func TestTransformDiffDataPackNewTables(t *testing.T) {
+ tcs := []transformDiffTestCase{
+ {
+ name: "No new tables",
+ in: schemaDiff{
+ oldAndNew: oldAndNew[schema.Schema]{},
+ tableDiffs: listDiff[schema.Table, tableDiff]{},
+ indexDiffs: listDiff[schema.Index, indexDiff]{},
+ functionDiffs: listDiff[schema.Function, functionDiff]{},
+ triggerDiffs: listDiff[schema.Trigger, triggerDiff]{},
+ },
+ expectedOut: schemaDiff{
+ oldAndNew: oldAndNew[schema.Schema]{},
+ tableDiffs: listDiff[schema.Table, tableDiff]{},
+ indexDiffs: listDiff[schema.Index, indexDiff]{},
+ functionDiffs: listDiff[schema.Function, functionDiff]{},
+ triggerDiffs: listDiff[schema.Trigger, triggerDiff]{},
+ },
+ },
+ {
+ name: "New table with no columns",
+ in: schemaDiff{
+ oldAndNew: oldAndNew[schema.Schema]{},
+ tableDiffs: listDiff[schema.Table, tableDiff]{
+ adds: []schema.Table{
+ {
+ Name: "foobar",
+ Columns: nil,
+ CheckConstraints: nil,
+ },
+ },
+ },
+ indexDiffs: listDiff[schema.Index, indexDiff]{},
+ functionDiffs: listDiff[schema.Function, functionDiff]{},
+ triggerDiffs: listDiff[schema.Trigger, triggerDiff]{},
+ },
+ expectedOut: schemaDiff{
+ oldAndNew: oldAndNew[schema.Schema]{},
+ tableDiffs: listDiff[schema.Table, tableDiff]{
+ adds: []schema.Table{
+ {
+ Name: "foobar",
+ Columns: nil,
+ CheckConstraints: nil,
+ },
+ },
+ },
+ indexDiffs: listDiff[schema.Index, indexDiff]{},
+ functionDiffs: listDiff[schema.Function, functionDiff]{},
+ triggerDiffs: listDiff[schema.Trigger, triggerDiff]{},
+ },
+ },
+ {
+ name: "Standard",
+ in: schemaDiff{
+ oldAndNew: oldAndNew[schema.Schema]{},
+ tableDiffs: listDiff[schema.Table, tableDiff]{
+ alters: []tableDiff{
+ buildTableDiffWithColDiffs("foo", []columnDiff{
+ buildColumnDiff(schema.Column{Name: "genre", Type: "some type", Size: 3}, 0, 0),
+ buildColumnDiff(schema.Column{Name: "content", Type: "some type", Size: 2}, 1, 1),
+ buildColumnDiff(schema.Column{Name: "title", Type: "some type", Size: 10}, 2, 2),
+ }),
+ },
+ adds: []schema.Table{
+ {
+ Name: "foobar",
+ Columns: []schema.Column{
+ {Name: "coffee", Type: "some type", Size: 3},
+ {Name: "mocha", Type: "some type", Size: 2},
+ {Name: "latte", Type: "some type", Size: 10},
+ },
+ CheckConstraints: nil,
+ },
+ {
+ Name: "baz",
+ Columns: []schema.Column{
+ {Name: "dog", Type: "some type", Size: 1},
+ {Name: "cat", Type: "some type", Size: 2},
+ {Name: "rabbit", Type: "some type", Size: 3},
+ },
+ CheckConstraints: nil,
+ },
+ },
+ deletes: []schema.Table{
+ {
+ Name: "fizz",
+ Columns: []schema.Column{
+ {Name: "croissant", Type: "some type", Size: 3},
+ {Name: "bagel", Type: "some type", Size: 2},
+ {Name: "pastry", Type: "some type", Size: 10},
+ },
+ CheckConstraints: nil,
+ },
+ },
+ },
+ indexDiffs: listDiff[schema.Index, indexDiff]{},
+ functionDiffs: listDiff[schema.Function, functionDiff]{},
+ triggerDiffs: listDiff[schema.Trigger, triggerDiff]{},
+ },
+ expectedOut: schemaDiff{
+ oldAndNew: oldAndNew[schema.Schema]{},
+ tableDiffs: listDiff[schema.Table, tableDiff]{
+ alters: []tableDiff{
+ buildTableDiffWithColDiffs("foo", []columnDiff{
+ buildColumnDiff(schema.Column{Name: "genre", Type: "some type", Size: 3}, 0, 0),
+ buildColumnDiff(schema.Column{Name: "content", Type: "some type", Size: 2}, 1, 1),
+ buildColumnDiff(schema.Column{Name: "title", Type: "some type", Size: 10}, 2, 2),
+ }),
+ },
+ adds: []schema.Table{
+ {
+ Name: "foobar",
+ Columns: []schema.Column{
+ {Name: "latte", Type: "some type", Size: 10},
+ {Name: "coffee", Type: "some type", Size: 3},
+ {Name: "mocha", Type: "some type", Size: 2},
+ },
+ CheckConstraints: nil,
+ },
+ {
+ Name: "baz",
+ Columns: []schema.Column{
+ {Name: "rabbit", Type: "some type", Size: 3},
+ {Name: "cat", Type: "some type", Size: 2},
+ {Name: "dog", Type: "some type", Size: 1},
+ },
+ CheckConstraints: nil,
+ },
+ },
+ deletes: []schema.Table{
+ {
+ Name: "fizz",
+ Columns: []schema.Column{
+ {Name: "croissant", Type: "some type", Size: 3},
+ {Name: "bagel", Type: "some type", Size: 2},
+ {Name: "pastry", Type: "some type", Size: 10},
+ },
+ CheckConstraints: nil,
+ },
+ },
+ },
+ indexDiffs: listDiff[schema.Index, indexDiff]{},
+ functionDiffs: listDiff[schema.Function, functionDiff]{},
+ triggerDiffs: listDiff[schema.Trigger, triggerDiff]{},
+ },
+ },
+ }
+ runTestCases(t, dataPackNewTables, tcs)
+}
+
+func TestTransformDiffRemoveChangesToColumnOrdering(t *testing.T) {
+ tcs := []transformDiffTestCase{
+ {
+ name: "No altered tables",
+ in: schemaDiff{
+ oldAndNew: oldAndNew[schema.Schema]{},
+ tableDiffs: listDiff[schema.Table, tableDiff]{},
+ indexDiffs: listDiff[schema.Index, indexDiff]{},
+ functionDiffs: listDiff[schema.Function, functionDiff]{},
+ triggerDiffs: listDiff[schema.Trigger, triggerDiff]{},
+ },
+ expectedOut: schemaDiff{
+ oldAndNew: oldAndNew[schema.Schema]{},
+ tableDiffs: listDiff[schema.Table, tableDiff]{},
+ indexDiffs: listDiff[schema.Index, indexDiff]{},
+ functionDiffs: listDiff[schema.Function, functionDiff]{},
+ triggerDiffs: listDiff[schema.Trigger, triggerDiff]{},
+ },
+ },
+ {
+ name: "Altered table with no columns",
+ in: schemaDiff{
+ oldAndNew: oldAndNew[schema.Schema]{},
+ tableDiffs: listDiff[schema.Table, tableDiff]{
+ alters: []tableDiff{buildTableDiffWithColDiffs("foobar", nil)},
+ },
+ indexDiffs: listDiff[schema.Index, indexDiff]{},
+ functionDiffs: listDiff[schema.Function, functionDiff]{},
+ triggerDiffs: listDiff[schema.Trigger, triggerDiff]{},
+ },
+ expectedOut: schemaDiff{
+ oldAndNew: oldAndNew[schema.Schema]{},
+ tableDiffs: listDiff[schema.Table, tableDiff]{
+ alters: []tableDiff{buildTableDiffWithColDiffs("foobar", nil)},
+ },
+ indexDiffs: listDiff[schema.Index, indexDiff]{},
+ functionDiffs: listDiff[schema.Function, functionDiff]{},
+ triggerDiffs: listDiff[schema.Trigger, triggerDiff]{},
+ },
+ },
+ {
+ name: "Standard",
+ in: schemaDiff{
+ oldAndNew: oldAndNew[schema.Schema]{},
+ tableDiffs: listDiff[schema.Table, tableDiff]{
+ alters: []tableDiff{
+ buildTableDiffWithColDiffs("foo", []columnDiff{
+ buildColumnDiff(schema.Column{Name: "genre", Type: "some type", Size: 3}, 0, 1),
+ buildColumnDiff(schema.Column{Name: "content", Type: "some type", Size: 2}, 1, 2),
+ buildColumnDiff(schema.Column{Name: "title", Type: "some type", Size: 10}, 2, 0),
+ }),
+ buildTableDiffWithColDiffs("bar", []columnDiff{
+ buildColumnDiff(schema.Column{Name: "item", Type: "some type", Size: 3}, 0, 2),
+ buildColumnDiff(schema.Column{Name: "type", Type: "some type", Size: 2}, 1, 1),
+ buildColumnDiff(schema.Column{Name: "color", Type: "some type", Size: 10}, 2, 0),
+ }),
+ },
+ adds: []schema.Table{
+ {
+ Name: "foobar",
+ Columns: []schema.Column{
+ {Name: "cold brew", Type: "some type", Size: 2},
+ {Name: "coffee", Type: "some type", Size: 3},
+ {Name: "mocha", Type: "some type", Size: 2},
+ {Name: "latte", Type: "some type", Size: 10},
+ },
+ CheckConstraints: nil,
+ },
+ },
+ deletes: []schema.Table{
+ {
+ Name: "fizz",
+ Columns: []schema.Column{
+ {Name: "croissant", Type: "some type", Size: 3},
+ {Name: "bagel", Type: "some type", Size: 2},
+ {Name: "pastry", Type: "some type", Size: 10},
+ },
+ CheckConstraints: nil,
+ },
+ },
+ },
+ indexDiffs: listDiff[schema.Index, indexDiff]{},
+ functionDiffs: listDiff[schema.Function, functionDiff]{},
+ triggerDiffs: listDiff[schema.Trigger, triggerDiff]{},
+ },
+ expectedOut: schemaDiff{
+ oldAndNew: oldAndNew[schema.Schema]{},
+ tableDiffs: listDiff[schema.Table, tableDiff]{
+ alters: []tableDiff{
+ buildTableDiffWithColDiffs("foo", []columnDiff{
+ buildColumnDiff(schema.Column{Name: "genre", Type: "some type", Size: 3}, 1, 1),
+ buildColumnDiff(schema.Column{Name: "content", Type: "some type", Size: 2}, 2, 2),
+ buildColumnDiff(schema.Column{Name: "title", Type: "some type", Size: 10}, 0, 0),
+ }),
+ buildTableDiffWithColDiffs("bar", []columnDiff{
+ buildColumnDiff(schema.Column{Name: "item", Type: "some type", Size: 3}, 2, 2),
+ buildColumnDiff(schema.Column{Name: "type", Type: "some type", Size: 2}, 1, 1),
+ buildColumnDiff(schema.Column{Name: "color", Type: "some type", Size: 10}, 0, 0),
+ }),
+ },
+ adds: []schema.Table{
+ {
+ Name: "foobar",
+ Columns: []schema.Column{
+ {Name: "cold brew", Type: "some type", Size: 2},
+ {Name: "coffee", Type: "some type", Size: 3},
+ {Name: "mocha", Type: "some type", Size: 2},
+ {Name: "latte", Type: "some type", Size: 10},
+ },
+ CheckConstraints: nil,
+ },
+ },
+ deletes: []schema.Table{
+ {
+ Name: "fizz",
+ Columns: []schema.Column{
+ {Name: "croissant", Type: "some type", Size: 3},
+ {Name: "bagel", Type: "some type", Size: 2},
+ {Name: "pastry", Type: "some type", Size: 10},
+ },
+ CheckConstraints: nil,
+ },
+ },
+ },
+ indexDiffs: listDiff[schema.Index, indexDiff]{},
+ functionDiffs: listDiff[schema.Function, functionDiff]{},
+ triggerDiffs: listDiff[schema.Trigger, triggerDiff]{},
+ },
+ },
+ }
+ runTestCases(t, removeChangesToColumnOrdering, tcs)
+}
+
+type transformDiffTestCase struct {
+ name string
+ in schemaDiff
+ expectedOut schemaDiff
+}
+
+func runTestCases(t *testing.T, transform func(diff schemaDiff) schemaDiff, tcs []transformDiffTestCase) {
+ for _, tc := range tcs {
+ t.Run(tc.name, func(t *testing.T) {
+ assert.Equal(t, tc.expectedOut, transform(tc.in))
+ })
+ }
+}
+
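+// buildTableDiffWithColDiffs builds a tableDiff for a table with the given name whose old and new
+// column lists are reconstructed from the provided column diffs: old columns are ordered by
+// oldOrdering and new columns by newOrdering.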
+func buildTableDiffWithColDiffs(name string, columnDiffs []columnDiff) tableDiff {
+ var oldColumns, newColumns []schema.Column
+ copiedColumnDiffs := append([]columnDiff(nil), columnDiffs...)
+ sort.Slice(copiedColumnDiffs, func(i, j int) bool {
+ return copiedColumnDiffs[i].oldOrdering < copiedColumnDiffs[j].oldOrdering
+ })
+ for _, colDiff := range copiedColumnDiffs {
+ oldColumns = append(oldColumns, colDiff.old)
+ }
+ sort.Slice(copiedColumnDiffs, func(i, j int) bool {
+ return copiedColumnDiffs[i].newOrdering < copiedColumnDiffs[j].newOrdering
+ })
+ for _, colDiff := range copiedColumnDiffs {
+ newColumns = append(newColumns, colDiff.new)
+ }
+
+ return tableDiff{
+ oldAndNew: oldAndNew[schema.Table]{
+ old: schema.Table{
+ Name: name,
+ Columns: oldColumns,
+ CheckConstraints: nil,
+ },
+ new: schema.Table{
+ Name: name,
+ Columns: newColumns,
+ CheckConstraints: nil,
+ },
+ },
+ columnsDiff: listDiff[schema.Column, columnDiff]{
+ alters: columnDiffs,
+ },
+ checkConstraintDiff: listDiff[schema.CheckConstraint, checkConstraintDiff]{},
+ }
+}
+
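+// buildColumnDiff builds a columnDiff in which the column itself is unchanged but its position
+// moves from oldOrdering to newOrdering.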
+func buildColumnDiff(col schema.Column, oldOrdering, newOrdering int) columnDiff {
+ return columnDiff{
+ oldAndNew: oldAndNew[schema.Column]{
+ old: col,
+ new: col,
+ },
+ oldOrdering: oldOrdering,
+ newOrdering: newOrdering,
+ }
+}
diff --git a/pkg/log/logger.go b/pkg/log/logger.go
new file mode 100644
index 0000000..3c1a2a5
--- /dev/null
+++ b/pkg/log/logger.go
@@ -0,0 +1,24 @@
+package log
+
+import (
+ "fmt"
+ "log"
+)
+
+type (
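+ // Logger is the minimal logging interface consumed by this library. Any type with a matching
+ // Errorf method satisfies it, e.g. a small adapter around an existing logger (a hypothetical
+ // sketch; zap is not a dependency of this package):
+ //
+ //	type zapAdapter struct{ s *zap.SugaredLogger }
+ //
+ //	func (a zapAdapter) Errorf(msg string, args ...any) { a.s.Errorf(msg, args...) }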
+ Logger interface {
+ Errorf(msg string, args ...any)
+ }
+
+ simpleLogger struct{}
+)
+
+// SimpleLogger returns a bare-bones implementation of the Logger interface, e.g., for use in testing
+func SimpleLogger() Logger {
+ return &simpleLogger{}
+}
+
+func (*simpleLogger) Errorf(msg string, args ...any) {
+ formattedMessage := fmt.Sprintf(msg, args...)
+ log.Printf("[ERROR] %s", formattedMessage)
+}
diff --git a/pkg/schema/schema.go b/pkg/schema/schema.go
new file mode 100644
index 0000000..f004070
--- /dev/null
+++ b/pkg/schema/schema.go
@@ -0,0 +1,25 @@
+package schema
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+
+ internalschema "github.com/stripe/pg-schema-diff/internal/schema"
+)
+
+// GetPublicSchemaHash gets the hash of the "public" schema. It can be compared against the hash stored in a
+// migration plan to determine whether the plan is still valid.
+// We do not expose the Schema struct yet because it is subject to change, and we do not want folks depending on its API.
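+//
+// A minimal usage sketch (assumes an existing *sql.DB connection pool to the target database; error handling elided):
+//
+//	conn, _ := connPool.Conn(ctx)
+//	hash, err := schema.GetPublicSchemaHash(ctx, conn)
+//	if err != nil {
+//		// handle error
+//	}
+//	// compare hash against the hash recorded when the migration plan was generated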
+func GetPublicSchemaHash(ctx context.Context, conn *sql.Conn) (string, error) {
+ schema, err := internalschema.GetPublicSchema(ctx, conn)
+ if err != nil {
+ return "", fmt.Errorf("getting public schema: %w", err)
+ }
+ hash, err := schema.Hash()
+ if err != nil {
+ return "", fmt.Errorf("hashing schema: %w", err)
+ }
+
+ return hash, nil
+}
diff --git a/pkg/schema/schema_test.go b/pkg/schema/schema_test.go
new file mode 100644
index 0000000..577c975
--- /dev/null
+++ b/pkg/schema/schema_test.go
@@ -0,0 +1,101 @@
+package schema_test
+
+import (
+ "context"
+ "database/sql"
+ "testing"
+
+ "github.com/stretchr/testify/suite"
+ "github.com/stripe/pg-schema-diff/internal/pgengine"
+ "github.com/stripe/pg-schema-diff/pkg/schema"
+)
+
+type schemaTestSuite struct {
+ suite.Suite
+ pgEngine *pgengine.Engine
+}
+
+func (suite *schemaTestSuite) SetupSuite() {
+ engine, err := pgengine.StartEngine()
+ suite.Require().NoError(err)
+ suite.pgEngine = engine
+}
+
+func (suite *schemaTestSuite) TearDownSuite() {
+ suite.pgEngine.Close()
+}
+
+func (suite *schemaTestSuite) TestGetPublicSchemaHash() {
+ const (
+ ddl = `
+ CREATE FUNCTION add(a integer, b integer) RETURNS integer
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN a + b;
+
+ CREATE FUNCTION increment(i integer) RETURNS integer AS $$
+ BEGIN
+ RETURN i + 1;
+ END;
+ $$ LANGUAGE plpgsql;
+
+ CREATE FUNCTION function_with_dependencies(a integer, b integer) RETURNS integer
+ LANGUAGE SQL
+ IMMUTABLE
+ RETURNS NULL ON NULL INPUT
+ RETURN add(a, b) + increment(a);
+
+ CREATE TABLE foo (
+ id INTEGER PRIMARY KEY,
+ author TEXT COLLATE "C",
+ content TEXT NOT NULL DEFAULT '',
+ created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP CHECK (created_at > CURRENT_TIMESTAMP - interval '1 month') NO INHERIT,
+ version INT NOT NULL DEFAULT 0,
+ CHECK ( function_with_dependencies(id, id) > 0)
+ );
+
+ ALTER TABLE foo ADD CONSTRAINT author_check CHECK (author IS NOT NULL AND LENGTH(author) > 0) NO INHERIT NOT VALID;
+ CREATE INDEX some_idx ON foo USING hash (content);
+ CREATE UNIQUE INDEX some_unique_idx ON foo (created_at DESC, author ASC);
+
+ CREATE FUNCTION increment_version() RETURNS TRIGGER AS $$
+ BEGIN
+ NEW.version = OLD.version + 1;
+ RETURN NEW;
+ END;
+ $$ language 'plpgsql';
+
+ CREATE TRIGGER some_trigger
+ BEFORE UPDATE ON foo
+ FOR EACH ROW
+ WHEN (OLD.* IS DISTINCT FROM NEW.*)
+ EXECUTE PROCEDURE increment_version();
+ `
+
+ expectedHash = "9648c294aed76ef6"
+ )
+ db, err := suite.pgEngine.CreateDatabase()
+ suite.Require().NoError(err)
+ defer db.DropDB()
+
+ connPool, err := sql.Open("pgx", db.GetDSN())
+ suite.Require().NoError(err)
+ defer connPool.Close()
+
+ _, err = connPool.ExecContext(context.Background(), ddl)
+ suite.Require().NoError(err)
+
+ conn, err := connPool.Conn(context.Background())
+ suite.Require().NoError(err)
+ defer conn.Close()
+
+ hash, err := schema.GetPublicSchemaHash(context.Background(), conn)
+ suite.Require().NoError(err)
+
+ suite.Equal(expectedHash, hash)
+}
+
+func TestSchemaTestSuite(t *testing.T) {
+ suite.Run(t, new(schemaTestSuite))
+}
diff --git a/pkg/tempdb/factory.go b/pkg/tempdb/factory.go
new file mode 100644
index 0000000..e502bb4
--- /dev/null
+++ b/pkg/tempdb/factory.go
@@ -0,0 +1,208 @@
+package tempdb
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "io"
+ "strings"
+
+ "github.com/google/uuid"
+ "github.com/jackc/pgx/v4"
+ _ "github.com/jackc/pgx/v4/stdlib"
+ "github.com/stripe/pg-schema-diff/internal/pgidentifier"
+ "github.com/stripe/pg-schema-diff/pkg/log"
+)
+
+const (
+ DefaultOnInstanceDbPrefix = "pgschemadifftmp_"
+ DefaultOnInstanceMetadataSchema = "pgschemadifftmp_metadata"
+ DefaultOnInstanceMetadataTable = "metadata"
+)
+
+// Factory is used to create temp databases. These databases do not have to be in-memory. They might, for example,
+// be created on the target Postgres server.
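+//
+// A typical create/use/drop flow looks roughly like this (sketch; error handling elided):
+//
+//	db, dropper, err := factory.Create(ctx)
+//	if err != nil {
+//		// handle error
+//	}
+//	defer dropper(ctx)
+//	// use db, e.g. to apply DDL and inspect the resulting schema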
+type (
+ Dropper func(ctx context.Context) error
+
+ Factory interface {
+ // Create creates a temporary database. Be sure to always call the Dropper to ensure the database and
+ // connections are cleaned up
+ Create(ctx context.Context) (db *sql.DB, dropper Dropper, err error)
+
+ io.Closer
+ }
+)
+
+type (
+ onInstanceFactoryOptions struct {
+ dbPrefix string
+ metadataSchema string
+ metadataTable string
+ logger log.Logger
+ }
+
+ OnInstanceFactoryOpt func(*onInstanceFactoryOptions)
+)
+
+// WithLogger sets the logger for the factory. If not set, a SimpleLogger will be used
+func WithLogger(logger log.Logger) OnInstanceFactoryOpt {
+ return func(opts *onInstanceFactoryOptions) {
+ opts.logger = logger
+ }
+}
+
+// WithDbPrefix sets the prefix for the temp database name
+func WithDbPrefix(prefix string) OnInstanceFactoryOpt {
+ return func(opts *onInstanceFactoryOptions) {
+ opts.dbPrefix = prefix
+ }
+}
+
+// WithMetadataSchema sets the name of the schema that will contain the metadata table
+func WithMetadataSchema(schema string) OnInstanceFactoryOpt {
+ return func(opts *onInstanceFactoryOptions) {
+ opts.metadataSchema = schema
+ }
+}
+
+// WithMetadataTable sets the metadata table name
+func WithMetadataTable(table string) OnInstanceFactoryOpt {
+ return func(opts *onInstanceFactoryOptions) {
+ opts.metadataTable = table
+ }
+}
+
+type (
+ CreateConnForDbFn func(ctx context.Context, dbName string) (*sql.DB, error)
+
+ // onInstanceFactory creates temporary databases on the provided Postgres server
+ onInstanceFactory struct {
+ rootDb *sql.DB
+ createConnForDb CreateConnForDbFn
+ options onInstanceFactoryOptions
+ }
+)
+
+// NewOnInstanceFactory provides an implementation to easily create temporary databases on the Postgres instance
+// connected to via CreateConnForDbFn. The Postgres instance is connected to via the "postgres" database, and then
+// temporary databases are created using that connection. These temporary databases are also connected to via the
+// CreateConnForDbFn.
+// Make sure to always call Close() on the returned Factory to ensure the root connection is closed
+//
+// WARNING:
+// It is possible this will lead to orphaned temporary databases. These orphaned databases should be pretty small if
+// they're only being used by the pg-schema-diff library, but it's recommended to clean them up when possible. This can
+// be done by deleting all old databases with the provided temp db prefix. The metadata table can be inspected to find
+// when the temporary database was created, e.g., to create a TTL
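+//
+// A minimal construction sketch (dsnForDatabase is a hypothetical helper; any function returning a
+// *sql.DB for the given database name works):
+//
+//	factory, err := tempdb.NewOnInstanceFactory(ctx, func(ctx context.Context, dbName string) (*sql.DB, error) {
+//		return sql.Open("pgx", dsnForDatabase(dbName))
+//	}, tempdb.WithDbPrefix("myapp_tmp_"))
+//	if err != nil {
+//		// handle error
+//	}
+//	defer factory.Close()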
+func NewOnInstanceFactory(ctx context.Context, createConnForDb CreateConnForDbFn, opts ...OnInstanceFactoryOpt) (Factory, error) {
+ options := onInstanceFactoryOptions{
+ dbPrefix: DefaultOnInstanceDbPrefix,
+ metadataSchema: DefaultOnInstanceMetadataSchema,
+ metadataTable: DefaultOnInstanceMetadataTable,
+ logger: log.SimpleLogger(),
+ }
+ for _, opt := range opts {
+ opt(&options)
+ }
+ if !pgidentifier.IsSimpleIdentifier(options.dbPrefix) {
+ return nil, fmt.Errorf("dbPrefix (%s) must be a simple Postgres identifier matching the following regex: %s", options.dbPrefix, pgidentifier.SimpleIdentifierRegex)
+ }
+
+ rootDb, err := createConnForDb(ctx, "postgres")
+ if err != nil {
+ return &onInstanceFactory{}, err
+ }
+ if err := assertConnPoolIsOnExpectedDatabase(ctx, rootDb, "postgres"); err != nil {
+ rootDb.Close()
+ return &onInstanceFactory{}, fmt.Errorf("assertConnPoolIsOnExpectedDatabase: %w", err)
+ }
+
+ return &onInstanceFactory{
+ rootDb: rootDb,
+ createConnForDb: createConnForDb,
+ options: options,
+ }, nil
+}
+
+func (o *onInstanceFactory) Close() error {
+ return o.rootDb.Close()
+}
+
+func (o *onInstanceFactory) Create(ctx context.Context) (db *sql.DB, dropper Dropper, retErr error) {
+ dbUUID, err := uuid.NewUUID()
+ if err != nil {
+ return nil, nil, err
+ }
+ tempDbName := o.options.dbPrefix + strings.ReplaceAll(dbUUID.String(), "-", "_")
+ if _, err = o.rootDb.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s;", tempDbName)); err != nil {
+ return nil, nil, err
+ }
+ defer func() {
+ // Only drop the temp database if an error occurred during creation
+ if retErr != nil {
+ if err := o.dropTempDatabase(ctx, tempDbName); err != nil {
+ o.options.logger.Errorf("Failed to drop temporary database %s because of error %s. This drop was automatically triggered by error %s", tempDbName, err.Error(), retErr.Error())
+ }
+ }
+ }()
+
+ tempDbConn, err := o.createConnForDb(ctx, tempDbName)
+ if err != nil {
+ return nil, nil, err
+ }
+ defer func() {
+ // Only close the connection pool if an error occurred during creation.
+ // We should close the connection pool on the off-chance that the drop database fails
+ if retErr != nil {
+ _ = tempDbConn.Close()
+ }
+ }()
+ if err := assertConnPoolIsOnExpectedDatabase(ctx, tempDbConn, tempDbName); err != nil {
+ return nil, nil, fmt.Errorf("assertConnPoolIsOnExpectedDatabase: %w", err)
+ }
+
+ // There's no easy way to keep track of when a database was created, so
+ // create a row with the temp database's time of creation. This will be used
+ // by a cleanup process to drop any temporary databases that were not cleaned up
+ // successfully by the "dropper"
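+ // A cleanup job can, for example, list candidates with a query along the lines of
+ //   SELECT datname FROM pg_database WHERE datname LIKE 'pgschemadifftmp_%';
+ // and then read db_created_at from each candidate's metadata table to enforce a TTL.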
+ sanitizedSchemaName := pgx.Identifier{o.options.metadataSchema}.Sanitize()
+ sanitizedTableName := pgx.Identifier{o.options.metadataSchema, o.options.metadataTable}.Sanitize()
+ createMetadataStmts := fmt.Sprintf(`
+ CREATE SCHEMA %s
+ CREATE TABLE %s(
+ db_created_at TIMESTAMPTZ NOT NULL DEFAULT current_timestamp
+ );
+ INSERT INTO %s DEFAULT VALUES;
+ `, sanitizedSchemaName, sanitizedTableName, sanitizedTableName)
+ if _, err := tempDbConn.ExecContext(ctx, createMetadataStmts); err != nil {
+ return nil, nil, err
+ }
+
+ return tempDbConn, func(ctx context.Context) error {
+ _ = tempDbConn.Close()
+ return o.dropTempDatabase(ctx, tempDbName)
+ }, nil
+
+}
+
+// assertConnPoolIsOnExpectedDatabase validates that the caller's CreateConnForDbFn actually connects to the expected database
+func assertConnPoolIsOnExpectedDatabase(ctx context.Context, connPool *sql.DB, expectedDatabase string) error {
+ var dbName string
+ if err := connPool.QueryRowContext(ctx, "SELECT current_database();").Scan(&dbName); err != nil {
+ return err
+ }
+ if dbName != expectedDatabase {
+ return fmt.Errorf("connection pool is on database %s, expected %s", dbName, expectedDatabase)
+ }
+
+ return nil
+}
+
+func (o *onInstanceFactory) dropTempDatabase(ctx context.Context, dbName string) error {
+ if !strings.HasPrefix(dbName, o.options.dbPrefix) {
+ return fmt.Errorf("drop non-temporary database: %s", dbName)
+ }
+ _, err := o.rootDb.ExecContext(ctx, fmt.Sprintf("DROP DATABASE %s;", dbName))
+ return err
+}
diff --git a/pkg/tempdb/factory_test.go b/pkg/tempdb/factory_test.go
new file mode 100644
index 0000000..f02fddb
--- /dev/null
+++ b/pkg/tempdb/factory_test.go
@@ -0,0 +1,190 @@
+package tempdb
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "strings"
+ "testing"
+ "time"
+
+ _ "github.com/jackc/pgx/v4/stdlib"
+ "github.com/stretchr/testify/suite"
+ "github.com/stripe/pg-schema-diff/internal/pgengine"
+
+ "github.com/stripe/pg-schema-diff/pkg/log"
+)
+
+type onInstanceTempDbFactorySuite struct {
+ suite.Suite
+
+ engine *pgengine.Engine
+}
+
+func (suite *onInstanceTempDbFactorySuite) SetupSuite() {
+ engine, err := pgengine.StartEngine()
+ suite.Require().NoError(err)
+ suite.engine = engine
+}
+
+func (suite *onInstanceTempDbFactorySuite) TearDownSuite() {
+ suite.engine.Close()
+}
+
+func (suite *onInstanceTempDbFactorySuite) mustBuildFactory(opt ...OnInstanceFactoryOpt) Factory {
+ factory, err := NewOnInstanceFactory(context.Background(), func(ctx context.Context, dbName string) (*sql.DB, error) {
+ return suite.getConnPoolForDb(dbName)
+ }, opt...)
+ suite.Require().NoError(err)
+ return factory
+}
+
+func (suite *onInstanceTempDbFactorySuite) getConnPoolForDb(dbName string) (*sql.DB, error) {
+ return sql.Open("pgx", suite.engine.GetPostgresDatabaseConnOpts().With("dbname", dbName).ToDSN())
+}
+
+func (suite *onInstanceTempDbFactorySuite) mustRunSQL(conn *sql.Conn) {
+ _, err := conn.ExecContext(context.Background(), `
+ CREATE TABLE foobar(
+ id INT PRIMARY KEY,
+ message TEXT
+ );
+ CREATE INDEX message_idx ON foobar(message);
+ `)
+ suite.Require().NoError(err)
+
+ _, err = conn.ExecContext(context.Background(), `
+ INSERT INTO foobar VALUES (1, 'some message'), (2, 'some other message'), (3, 'a final message');
+ `)
+ suite.Require().NoError(err)
+
+ res, err := conn.QueryContext(context.Background(), `
+ SELECT id, message FROM foobar;
+ `)
+ suite.Require().NoError(err)
+
+ var rows [][]any
+ for res.Next() {
+ var id int
+ var message string
+ suite.Require().NoError(res.Scan(&id, &message))
+ rows = append(rows, []any{
+ id, message,
+ })
+ }
+ suite.ElementsMatch([][]any{
+ {1, "some message"},
+ {2, "some other message"},
+ {3, "a final message"},
+ }, rows)
+}
+
+func (suite *onInstanceTempDbFactorySuite) TestNew_ConnectsToWrongDatabase() {
+ db, err := suite.engine.CreateDatabaseWithName("not-postgres")
+ suite.Require().NoError(err)
+ defer db.DropDB()
+
+ _, err = NewOnInstanceFactory(context.Background(), func(ctx context.Context, dbName string) (*sql.DB, error) {
+ return suite.getConnPoolForDb("not-postgres")
+ })
+ suite.ErrorContains(err, "connection pool is on")
+}
+
+func (suite *onInstanceTempDbFactorySuite) TestNew_ErrorsOnNonSimpleDbPrefix() {
+ _, err := NewOnInstanceFactory(context.Background(), func(ctx context.Context, dbName string) (*sql.DB, error) {
+ suite.Fail("shouldn't be reached")
+ return nil, nil
+ }, WithDbPrefix("non-simple identifier"))
+ suite.ErrorContains(err, "must be a simple Postgres identifier")
+}
+
+func (suite *onInstanceTempDbFactorySuite) TestCreate_CreateAndDropFlow() {
+ const (
+ dbPrefix = "some_prefix"
+ metadataSchema = "some metadata schema"
+ metadataTable = "some metadata table"
+ )
+ factory := suite.mustBuildFactory(
+ WithDbPrefix(dbPrefix),
+ WithMetadataSchema(metadataSchema),
+ WithMetadataTable(metadataTable),
+ WithLogger(log.SimpleLogger()),
+ )
+ defer func(factory Factory) {
+ suite.Require().NoError(factory.Close())
+ }(factory)
+
+ db, dropper, err := factory.Create(context.Background())
+ suite.Require().NoError(err)
+ // don't defer dropping. we want to run assertions after it drops. if dropping fails,
+ // it shouldn't be a problem because names shouldn't conflict
+ afterTimeOfCreation := time.Now()
+
+ conn1, err := db.Conn(context.Background())
+ suite.Require().NoError(err)
+
+ var dbName string
+ suite.Require().NoError(conn1.QueryRowContext(context.Background(), "SELECT current_database()").Scan(&dbName))
+ suite.True(strings.HasPrefix(dbName, dbPrefix))
+ suite.Len(dbName, len(dbPrefix)+36) // should be length of prefix + length of uuid
+
+ // Make sure SQL can run on the connection
+ suite.mustRunSQL(conn1)
+
+ // check the metadata entry exists
+ var createdAt time.Time
+ metadataQuery := fmt.Sprintf(`
+ SELECT * FROM "%s"."%s"
+ `, metadataSchema, metadataTable)
+ suite.Require().NoError(conn1.QueryRowContext(context.Background(), metadataQuery).Scan(&createdAt))
+ suite.True(createdAt.Before(afterTimeOfCreation))
+
+ // get another connection from the pool and make sure it's also set to the correct db while
+ // the other connection is still open
+ conn2, err := db.Conn(context.Background())
+ suite.Require().NoError(err)
+ var dbNameFromConn2 string
+ suite.Require().NoError(conn2.QueryRowContext(context.Background(), "SELECT current_database()").Scan(&dbNameFromConn2))
+ suite.Equal(dbName, dbNameFromConn2)
+
+ suite.Require().NoError(conn1.Close())
+ suite.Require().NoError(conn2.Close())
+
+ // drop database
+ suite.Require().NoError(db.Close())
+ suite.Require().NoError(dropper(context.Background()))
+
+ // expect an error when attempting to query the database, since it should be dropped.
+ // when a db pool is opened, it has no connections.
+ // a query is needed in order to find if the database still exists.
+ conn, err := suite.getConnPoolForDb(dbName)
+ suite.Require().NoError(err)
+ suite.Require().ErrorContains(conn.QueryRowContext(context.Background(), metadataQuery).Scan(&createdAt), "SQLSTATE 3D000")
+ suite.True(createdAt.Before(afterTimeOfCreation))
+}
+
+func (suite *onInstanceTempDbFactorySuite) TestCreate_ConnectsToWrongDatabase() {
+ factory, err := NewOnInstanceFactory(context.Background(), func(ctx context.Context, dbName string) (*sql.DB, error) {
+ return suite.getConnPoolForDb("postgres")
+ })
+ suite.Require().NoError(err)
+ defer func(factory Factory) {
+ suite.Require().NoError(factory.Close())
+ }(factory)
+
+ _, _, err = factory.Create(context.Background())
+ suite.ErrorContains(err, "connection pool is on")
+}
+
+func (suite *onInstanceTempDbFactorySuite) TestDropTempDB_CannotDropNonTempDb() {
+ factory := suite.mustBuildFactory()
+ defer func(factory Factory) {
+ suite.Require().NoError(factory.Close())
+ }(factory)
+
+ suite.ErrorContains(factory.(*onInstanceFactory).dropTempDatabase(context.Background(), "some_db"), "drop non-temporary database")
+}
+
+func TestOnInstanceFactorySuite(t *testing.T) {
+ suite.Run(t, new(onInstanceTempDbFactorySuite))
+}
diff --git a/scripts/codegen.sh b/scripts/codegen.sh
new file mode 100755
index 0000000..95af2f3
--- /dev/null
+++ b/scripts/codegen.sh
@@ -0,0 +1,4 @@
+#!/bin/sh
+
+docker build -t pg-schema-diff-code-gen-runner -f ./build/Dockerfile.codegen .
+docker run --rm -v "$(pwd)":/pg-schema-diff -w /pg-schema-diff pg-schema-diff-code-gen-runner