diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..e90e11a --- /dev/null +++ b/.editorconfig @@ -0,0 +1,22 @@ +root = true + +[*] +charset = utf-8 +end_of_line = lf +insert_final_newline = true +trim_trailing_whitespace = true + +[{*.go,Makefile,.gitmodules,go.mod,go.sum}] +indent_style = tab + +[*.md] +indent_style = tab +trim_trailing_whitespace = false + +[*.{yml,yaml,json}] +indent_style = space +indent_size = 2 + +[*.{js,jsx,ts,tsx,css,less,sass,scss,vue,py}] +indent_style = space +indent_size = 4 diff --git a/.github/workflows/code-gen.yaml b/.github/workflows/code-gen.yaml new file mode 100644 index 0000000..1fdd72d --- /dev/null +++ b/.github/workflows/code-gen.yaml @@ -0,0 +1,23 @@ +name: code_gen +on: + push: + branches: + - main + pull_request: +permissions: + contents: read +jobs: + code_gen: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Build Docker image + run: docker build -t pg-schema-diff-code-gen-runner -f ./build/Dockerfile.codegen . + - name: Run codegen + run: docker run -v $(pwd):/pg-schema-diff -w /pg-schema-diff pg-schema-diff-code-gen-runner + - name: Check for changes + run: | + chmod +x ./build/ci-scripts/assert-no-diff.sh + ./build/ci-scripts/assert-no-diff.sh + shell: bash + diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml new file mode 100644 index 0000000..44dfa27 --- /dev/null +++ b/.github/workflows/lint.yaml @@ -0,0 +1,21 @@ +name: lint +on: + push: + branches: + - main + pull_request: +permissions: + contents: read +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/setup-go@v4 + with: + go-version: '1.18' + cache: false + - uses: actions/checkout@v3 + - name: golangci-lint + uses: golangci/golangci-lint-action@v3 + with: + version: v1.52.2 diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml new file mode 100644 index 0000000..80a806f --- /dev/null +++ b/.github/workflows/run-tests.yaml @@ -0,0 +1,20 @@ +name: run_tests + +on: + push: + branches: + - main + pull_request: +jobs: + run_tests: + runs-on: ubuntu-latest + strategy: + matrix: + pg_version: [14, 15] + steps: + - name: Checkout code + uses: actions/checkout@v3 + - name: Build Docker image + run: docker build -t pg-schema-diff-test-runner -f ./build/Dockerfile.test --build-arg POSTGRES_PACKAGE=postgresql${{ matrix.pg_version }} . + - name: Run tests + run: docker run pg-schema-diff-test-runner diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..550b268 --- /dev/null +++ b/.gitignore @@ -0,0 +1,26 @@ +# Mac OS X files +.DS_Store + +# Binaries for programs and plugins +bin/ +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Test binary, build with `go test -c` +*.test +!/build/Dockerfile.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +# Project-local glide cache, RE: https://github.com/Masterminds/glide/issues/736 +.glide/ + +# Dependency directories +vendor/ + +# Intellij +.idea/ diff --git a/.golangci.yaml b/.golangci.yaml new file mode 100644 index 0000000..0fbdc18 --- /dev/null +++ b/.golangci.yaml @@ -0,0 +1,9 @@ +linters: + disable-all: true + enable: + - goimports + - ineffassign + - staticcheck + - typecheck + - unused + diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000..25c4774 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,76 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to making participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or + advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic + address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies both within project spaces and in public spaces +when an individual is representing the project or its community. Examples of +representing a project or community include using an official project e-mail +address, posting via an official social media account, or acting as an appointed +representative at an online or offline event. Representation of a project may be +further defined and clarified by project maintainers. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at conduct@stripe.com. All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see +https://www.contributor-covenant.org/faq diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..c113636 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,24 @@ +# Contributing + +This project is in its early stages. We appreciate all the feature/bug requests we receive, but we have limited cycles +to review direct code contributions at this time. We will try and respond to any bug reports, feature requests, and +questions within one week. + +If you want to make changes yourself, follow these steps: + +1. [Fork](https://help.github.com/articles/fork-a-repo/) this repository and [clone](https://help.github.com/articles/cloning-a-repository/) it locally. +2. Make your changes +3. Test your changes +```bash + docker build -t pg-schema-diff-test-runner -f ./build/Dockerfile.test --build-arg POSTGRES_PACKAGE=postgresql{14, 15} . + ``` +3. Submit a [pull request](https://help.github.com/articles/creating-a-pull-request-from-a-fork/) + +## Contributor License Agreement ([CLA](https://en.wikipedia.org/wiki/Contributor_License_Agreement)) + +Once you have submitted a pull request, sign the CLA by clicking on the badge in the comment from [@CLAassistant](https://github.com/CLAassistant). + +image + +
+Thanks for contributing to Stripe! diff --git a/LICENSE.md b/LICENSE.md new file mode 100644 index 0000000..34f864a --- /dev/null +++ b/LICENSE.md @@ -0,0 +1,21 @@ +The MIT License + +Copyright (c) 2023- Stripe, Inc. (https://stripe.com) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..b75b59e --- /dev/null +++ b/Makefile @@ -0,0 +1,16 @@ +.PHONY: code_gen format lint test vendor sqlc + +code_gen: sqlc + +format: + goimports -w . + +lint: + golangci-lint run + +sqlc: + cd internal/queries && sqlc generate + +vendor: + go mod vendor + diff --git a/README.md b/README.md new file mode 100644 index 0000000..ecdb281 --- /dev/null +++ b/README.md @@ -0,0 +1,104 @@ +# pg-schema-diff + +Diffs Postgres database schemas and generates the the SQL required to get your database schema from point A to B. This +enables you to take your database and migrate it to any desired schema defined with plain DDL. + +The tooling attempts to use native postgres migration operations and avoid locking wherever possible. Not all migrations will +be lock-free and some might require downtime, but the hazards system will warn you ahead of time when that's the case. +Stateful online migration techniques, like shadow tables, aren't yet supported. + +``` +pg-schema-diff plan --dsn "postgres://postgres:postgres@localhost:5432/postgres" --schema-dir schema + +################################ Generated plan ################################ +1. ALTER TABLE "foobar" ADD COLUMN "fizz" character varying(255) COLLATE "pg_catalog"."default"; + -- Timeout: 3s + +2. CREATE INDEX CONCURRENTLY fizz_idx ON public.foobar USING btree (fizz); + -- Timeout: 20m0s + -- Hazard INDEX_BUILD: This might affect database performance. Concurrent index builds require a non-trivial amount of CPU, potentially affecting database performance. They also can take a while but do not lock out writes. +``` + +# Key features +*Broad support for diffing & applying arbitrary postgres schemas defined in declarative DDL:* +- Tables +- Columns +- Check Constraints +- Indexes +- Partitions +- Functions/Triggers (functions created by extensions are ignored) + +*A comprehensive set of features to ensure the safety of planned migrations:* +- Dangerous operations are flagged as hazards and must be approved before a migration can be applied. + - Data deletion hazards identify operations which will in some way delete or alter data. + - Downtime/locking hazards identify operations which will impede or stop other queries. + - Performance hazards identify operations which are resource intensive and might slow other queries. +- Migration plans are validated first against a temporary database exactly as they would be performed against the real database. +- The library is tested against an extensive suite of unit and acceptance tests. + +*The use of postgres native operations for zero-downtime migrations wherever possible:* +- Concurrent index builds +- Online index replacement + +# Install +## Library +```bash +go get -u github.com/stripe/pg-schema-diff +```` +## CLI +```bash +go install github.com/stripe/pg-schema-diff/cmd/pg-schema-diff +``` + +# Using CLI +## 1. Apply schema to fresh database +Create a directory to hold your schema files. Then, generate sql files and place them into a schema dir. +```bash +mkdir schema +echo "CREATE TABLE foobar (id int);" > schema/foobar.sql +echo "CREATE TABLE bar (id varchar(255), message TEXT NOT NULL);" > schema/bar.sql +``` + +Apply the schema to a fresh database. [The connection string spec can be found here](https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING). +Setting the `PGPASSWORD` env var will override any password set in the connection string and is recommended. +```bash +pg-schema-diff apply --dsn "postgres://postgres:postgres@localhost:5432/postgres" --schema-dir schema +``` + +## 2. Updating schema +Update the SQL file(s) +```bash +echo "CREATE INDEX message_idx ON bar(message)" >> schema/bar.sql +``` + +Apply the schema. Any hazards in the generated plan must be approved +```bash +pg-schema-diff apply --dsn "postgres://postgres:postgres@localhost:5432/postgres" --schema-dir schema --allow-hazards INDEX_BUILD +``` + +# Supported Postgres versions +- 14 (tested with 14.7) +- 15 (tested with 15.2) + +Postgres v13 and below are not supported. Use at your own risk. + +# Unsupported migrations +Note, the library only currently supports diffing the *public* schema. Support for diffing other schemas is on the roadmap + +*Unsupported*: +- (On roadmap) Foreign key constraints +- (On roadmap) Diffing schemas other than "public" +- (On roadmap) Serials and sequences +- (On roadmap) Unique constraints (unique indexes are supported but not unique constraints) +- (On roadmap) Adding and remove partitions from an existing partitioned table +- (On roadmap) Check constraints localized to specific partitions +- Partitioned partitions (partitioned tables are supported but not partitioned partitions) +- Materialized views +- Renaming. The diffing library relies on names to identify the old and new versions of a table, index, etc. If you rename +an object, it will be treated as a drop and an add + +# Contributing +This project is in its early stages. We appreciate all the feature/bug requests we receive, but we have limited cycles +to review direct code contributions at this time. See [Contributing](CONTRIBUTING.md) to learn more. + + diff --git a/build/Dockerfile.codegen b/build/Dockerfile.codegen new file mode 100644 index 0000000..0940591 --- /dev/null +++ b/build/Dockerfile.codegen @@ -0,0 +1,10 @@ +FROM golang:1.18.10-alpine3.17 + +RUN apk update && \ + apk add --no-cache \ + build-base \ + git \ + make + +RUN go install github.com/kyleconroy/sqlc/cmd/sqlc@v1.13.0 +ENTRYPOINT make code_gen diff --git a/build/Dockerfile.test b/build/Dockerfile.test new file mode 100644 index 0000000..68c292c --- /dev/null +++ b/build/Dockerfile.test @@ -0,0 +1,26 @@ +FROM golang:1.18.10-alpine3.17 + +ARG POSTGRES_PACKAGE + +RUN apk update && \ + apk add --no-cache \ + build-base \ + make \ + $POSTGRES_PACKAGE \ + postgresql-contrib \ + postgresql-client + +WORKDIR /pg-schema-diff + +COPY . . + +# Download dependencies so they are cached in a layer +RUN go mod download + +# Run all tests from non-root. This will also prevent Postgres from complaining when +# we try to launch it within tests +RUN adduser --disabled-password --gecos '' testrunner +USER testrunner + +# Run tests serially so logs can be streamed. Set overall timeout to 30m (the default is 10m, which is not enough) +ENTRYPOINT ["go", "test", "-v", "-race", "-p", "1", "./...", "-timeout", "30m"] diff --git a/build/ci-scripts/assert-no-diff.sh b/build/ci-scripts/assert-no-diff.sh new file mode 100755 index 0000000..b6bb518 --- /dev/null +++ b/build/ci-scripts/assert-no-diff.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +git_status=$(git status --porcelain) +if [[ -n $git_status ]]; then + echo "Changes to generated files detected $git_status" + exit 1 +fi diff --git a/cmd/pg-schema-diff/apply_cmd.go b/cmd/pg-schema-diff/apply_cmd.go new file mode 100644 index 0000000..24eda9f --- /dev/null +++ b/cmd/pg-schema-diff/apply_cmd.go @@ -0,0 +1,173 @@ +package main + +import ( + "context" + "errors" + "fmt" + "sort" + "strings" + "time" + + "github.com/jackc/pgx/v4" + "github.com/spf13/cobra" + "github.com/stripe/pg-schema-diff/pkg/diff" + "github.com/stripe/pg-schema-diff/pkg/log" +) + +func buildApplyCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "apply", + Short: "Migrate your database to the match the inputted schema (apply the schema to the database)", + } + + connFlags := createConnFlags(cmd) + planFlags := createPlanFlags(cmd) + allowedHazardsTypesStrs := cmd.Flags().StringSlice("allow-hazards", nil, + "Specify the hazards that are allowed. Order does not matter, and duplicates are ignored. If the"+ + " migration plan contains unwanted hazards (hazards not in this list), then the migration will fail to run"+ + " (example: --allowed-hazards DELETES_DATA,INDEX_BUILD)") + lockTimeout := cmd.Flags().Duration("lock-timeout", 30*time.Second, "the max time to wait to acquire a lock. 0 implies no timeout") + cmd.RunE = func(cmd *cobra.Command, args []string) error { + connConfig, err := connFlags.parseConnConfig() + if err != nil { + return err + } + + planConfig, err := planFlags.parsePlanConfig() + if err != nil { + return err + } + + if *lockTimeout < 0 { + return errors.New("lock timeout must be >= 0") + } + + cmd.SilenceUsage = true + + plan, err := generatePlan(context.Background(), log.SimpleLogger(), connConfig, planConfig) + if err != nil { + return err + } else if len(plan.Statements) == 0 { + fmt.Println("Schema matches expected. No plan generated") + return nil + } + + fmt.Println(header("Review plan")) + fmt.Print(planToPrettyS(plan), "\n\n") + + if err := failIfHazardsNotAllowed(plan, *allowedHazardsTypesStrs); err != nil { + return err + } + if err := mustContinuePrompt( + fmt.Sprintf( + "Apply migration with the following hazards: %s?", + strings.Join(*allowedHazardsTypesStrs, ", "), + ), + ); err != nil { + return err + } + + if err := runPlan(context.Background(), connConfig, plan, lockTimeout); err != nil { + return err + } + fmt.Println("Schema applied successfully") + return nil + } + + return cmd +} + +func failIfHazardsNotAllowed(plan diff.Plan, allowedHazardsTypesStrs []string) error { + isAllowedByHazardType := make(map[diff.MigrationHazardType]bool) + for _, val := range allowedHazardsTypesStrs { + isAllowedByHazardType[strings.ToUpper(val)] = true + } + var disallowedHazardMsgs []string + for i, stmt := range plan.Statements { + var disallowedTypes []diff.MigrationHazardType + for _, hzd := range stmt.Hazards { + if !isAllowedByHazardType[hzd.Type] { + disallowedTypes = append(disallowedTypes, hzd.Type) + } + } + if len(disallowedTypes) > 0 { + disallowedHazardMsgs = append(disallowedHazardMsgs, + fmt.Sprintf("- Statement %d: %s", getDisplayableStmtIdx(i), strings.Join(disallowedTypes, ", ")), + ) + } + + } + if len(disallowedHazardMsgs) > 0 { + return errors.New(fmt.Sprintf( + "Prohited hazards found\n"+ + "These hazards must be allowed via the allow-hazards flag, e.g., --allow-hazards %s\n"+ + "Prohibited hazards in the following statements:\n%s", + strings.Join(getHazardTypes(plan), ","), + strings.Join(disallowedHazardMsgs, "\n"), + )) + } + return nil +} + +func runPlan(ctx context.Context, connConfig *pgx.ConnConfig, plan diff.Plan, lockTimeout *time.Duration) error { + connPool, err := openDbWithPgxConfig(connConfig) + if err != nil { + return err + } + defer connPool.Close() + + conn, err := connPool.Conn(ctx) + if err != nil { + return err + } + defer conn.Close() + + _, err = conn.ExecContext(ctx, fmt.Sprintf("SET SESSION lock_timeout = %d", lockTimeout.Milliseconds())) + if err != nil { + return fmt.Errorf("setting lock timeout: %w", err) + } + + // Due to the way *sql.Db works, when a statement_timeout is set for the session, it will NOT reset + // by default when it's returned to the pool. + // + // We can't set the timeout at the TRANSACTION-level (for each transaction) because `ADD INDEX CONCURRENTLY` + // must be executed within its own transaction block. Postgres will error if you try to set a TRANSACTION-level + // timeout for it. SESSION-level statement_timeouts are respected by `ADD INDEX CONCURRENTLY` + for i, stmt := range plan.Statements { + fmt.Println(header(fmt.Sprintf("Executing statement %d", getDisplayableStmtIdx(i)))) + fmt.Printf("%s\n\n", statementToPrettyS(stmt)) + start := time.Now() + if _, err := conn.ExecContext(ctx, fmt.Sprintf("SET SESSION statement_timeout = %d", stmt.Timeout.Milliseconds())); err != nil { + return fmt.Errorf("setting statement timeout: %w", err) + } + if _, err := conn.ExecContext(ctx, stmt.ToSQL()); err != nil { + // could the migration statement contain sensitive information? + return fmt.Errorf("executing migration statement. the database maybe be in a dirty state: %s: %w", stmt, err) + } + fmt.Printf("Finished executing statement. Duration: %s\n", time.Since(start)) + } + fmt.Println(header("Complete")) + + return nil +} + +func getHazardTypes(plan diff.Plan) []diff.MigrationHazardType { + seenHazardTypes := make(map[diff.MigrationHazardType]bool) + var hazardTypes []diff.MigrationHazardType + for _, stmt := range plan.Statements { + for _, hazard := range stmt.Hazards { + if !seenHazardTypes[hazard.Type] { + seenHazardTypes[hazard.Type] = true + hazardTypes = append(hazardTypes, hazard.Type) + } + } + } + sort.Slice(hazardTypes, func(i, j int) bool { + return hazardTypes[i] < hazardTypes[j] + }) + return hazardTypes +} + +func getDisplayableStmtIdx(i int) int { + return i + 1 +} diff --git a/cmd/pg-schema-diff/cli.go b/cmd/pg-schema-diff/cli.go new file mode 100644 index 0000000..edd447d --- /dev/null +++ b/cmd/pg-schema-diff/cli.go @@ -0,0 +1,46 @@ +package main + +import ( + "fmt" + "math" + "strings" + + "github.com/manifoldco/promptui" +) + +func header(header string) string { + const headerTargetWidth = 80 + + if len(header) > headerTargetWidth { + return header + } + + if len(header) > 0 { + header = fmt.Sprintf(" %s ", header) + } + hashTagsOnSide := int(math.Ceil(float64(headerTargetWidth-len(header)) / 2)) + + rightHashTags := strings.Repeat("#", hashTagsOnSide) + leftHashTags := rightHashTags + if headerTargetWidth-len(header)-2*hashTagsOnSide > 0 { + leftHashTags += "#" + } + return fmt.Sprintf("%s%s%s", leftHashTags, header, rightHashTags) +} + +// MustContinuePrompt prompts the user if they want to continue, and returns an error otherwise. +// promptui requires the ContinueLabel to be one line +func mustContinuePrompt(continueLabel string) error { + if len(continueLabel) == 0 { + continueLabel = "Continue?" + } + if _, result, err := (&promptui.Select{ + Label: continueLabel, + Items: []string{"No", "Yes"}, + }).Run(); err != nil { + return err + } else if result == "No" { + return fmt.Errorf("user aborted") + } + return nil +} diff --git a/cmd/pg-schema-diff/flags.go b/cmd/pg-schema-diff/flags.go new file mode 100644 index 0000000..9dfe478 --- /dev/null +++ b/cmd/pg-schema-diff/flags.go @@ -0,0 +1,42 @@ +package main + +import ( + "os" + + "github.com/jackc/pgx/v4" + "github.com/spf13/cobra" +) + +type connFlags struct { + dsn *string +} + +func createConnFlags(cmd *cobra.Command) connFlags { + schemaDir := cmd.Flags().String("dsn", "", "Connection string for the database (DB password can be specified through PGPASSWORD environment variable)") + mustMarkFlagAsRequired(cmd, "dsn") + + return connFlags{ + dsn: schemaDir, + } +} + +func (c connFlags) parseConnConfig() (*pgx.ConnConfig, error) { + config, err := pgx.ParseConfig(*c.dsn) + if err != nil { + return nil, err + } + + if config.Password == "" { + if pgPassword := os.Getenv("PGPASSWORD"); pgPassword != "" { + config.Password = pgPassword + } + } + + return config, nil +} + +func mustMarkFlagAsRequired(cmd *cobra.Command, flagName string) { + if err := cmd.MarkFlagRequired(flagName); err != nil { + panic(err) + } +} diff --git a/cmd/pg-schema-diff/main.go b/cmd/pg-schema-diff/main.go new file mode 100644 index 0000000..1e92303 --- /dev/null +++ b/cmd/pg-schema-diff/main.go @@ -0,0 +1,25 @@ +package main + +import ( + "os" + + "github.com/spf13/cobra" +) + +// rootCmd represents the base command when called without any subcommands +var rootCmd = &cobra.Command{ + Use: "pg-schema-diff", + Short: "Diff two Postgres schemas and generate the SQL to get from one to the other", +} + +func init() { + rootCmd.AddCommand(buildPlanCmd()) + rootCmd.AddCommand(buildApplyCmd()) +} + +func main() { + err := rootCmd.Execute() + if err != nil { + os.Exit(1) + } +} diff --git a/cmd/pg-schema-diff/plan_cmd.go b/cmd/pg-schema-diff/plan_cmd.go new file mode 100644 index 0000000..5e3511b --- /dev/null +++ b/cmd/pg-schema-diff/plan_cmd.go @@ -0,0 +1,331 @@ +package main + +import ( + "context" + "database/sql" + "fmt" + "os" + "path/filepath" + "regexp" + "strconv" + "strings" + "time" + + "github.com/jackc/pgx/v4" + "github.com/spf13/cobra" + "github.com/stripe/pg-schema-diff/pkg/diff" + "github.com/stripe/pg-schema-diff/pkg/log" + "github.com/stripe/pg-schema-diff/pkg/tempdb" +) + +var ( + // Match arguments in the format "regex=duration" where duration is any duration valid in time.ParseDuration + // We'll let time.ParseDuration handle the complexity of parsing invalid duration, so the regex we're extracting is + // all characters greedily up to the rightmost "=" + statementTimeoutModifierRegex = regexp.MustCompile(`^(?P.+)=(?P.+)$`) + regexSTMRegexIndex = statementTimeoutModifierRegex.SubexpIndex("regex") + durationSTMRegexIndex = statementTimeoutModifierRegex.SubexpIndex("duration") + + // Match arguments in the format "index duration:statement" where duration is any duration valid in + // time.ParseDuration. In order to prevent matching on ":" in the duration, limit the character to just letters + // and numbers. To keep the regex simple, we won't bother matching on a more specific pattern for durations. + // time.ParseDuration can handle the complexity of parsing invalid durations + insertStatementRegex = regexp.MustCompile(`^(?P\d+) (?P[a-zA-Z0-9\.]+):(?P.+?);?$`) + indexInsertStatementRegexIndex = insertStatementRegex.SubexpIndex("index") + durationInsertStatementRegexIndex = insertStatementRegex.SubexpIndex("duration") + ddlInsertStatementRegexIndex = insertStatementRegex.SubexpIndex("ddl") +) + +func buildPlanCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "plan", + Aliases: []string{"diff"}, + Short: "Generate the diff between two databases and the SQL to get from one to the other", + } + + connFlags := createConnFlags(cmd) + planFlags := createPlanFlags(cmd) + cmd.RunE = func(cmd *cobra.Command, args []string) error { + connConfig, err := connFlags.parseConnConfig() + if err != nil { + return err + } + + planConfig, err := planFlags.parsePlanConfig() + if err != nil { + return err + } + + cmd.SilenceUsage = true + + plan, err := generatePlan(context.Background(), log.SimpleLogger(), connConfig, planConfig) + if err != nil { + return err + } else if len(plan.Statements) == 0 { + fmt.Println("Schema matches expected. No plan generated") + return nil + } + fmt.Printf("\n%s\n", header("Generated plan")) + fmt.Println(planToPrettyS(plan)) + return nil + } + + return cmd +} + +type ( + planFlags struct { + schemaDir *string + statementTimeoutModifiers *[]string + insertStatements *[]string + } + + statementTimeoutModifier struct { + regex *regexp.Regexp + timeout time.Duration + } + + insertStatement struct { + ddl string + index int + timeout time.Duration + } + + planConfig struct { + schemaDir string + statementTimeoutModifiers []statementTimeoutModifier + insertStatements []insertStatement + } +) + +func createPlanFlags(cmd *cobra.Command) planFlags { + schemaDir := cmd.Flags().String("schema-dir", "", "Directory containing schema files") + mustMarkFlagAsRequired(cmd, "schema-dir") + + statementTimeoutModifiers := cmd.Flags().StringArrayP("statement-timeout-modifier", "t", nil, + "regex=timeout key-value pairs, where if a statement matches the regex, the statement will have the target"+ + " timeout. If multiple regexes match, the latest regex will take priority. Example: -t 'CREATE TABLE=5m' -t 'CONCURRENTLY=10s'") + insertStatements := cmd.Flags().StringArrayP("insert-statement", "s", nil, + "_: values. Will insert the statement at the index in the "+ + "generated plan with the specified timeout. This follows normal insert semantics. Example: -s '0 5s:SELECT 1''") + + return planFlags{ + schemaDir: schemaDir, + statementTimeoutModifiers: statementTimeoutModifiers, + insertStatements: insertStatements, + } +} + +func (p planFlags) parsePlanConfig() (planConfig, error) { + var statementTimeoutModifiers []statementTimeoutModifier + for _, s := range *p.statementTimeoutModifiers { + stm, err := parseStatementTimeoutModifierStr(s) + if err != nil { + return planConfig{}, fmt.Errorf("parsing statement timeout modifier from %q: %w", s, err) + } + statementTimeoutModifiers = append(statementTimeoutModifiers, stm) + } + + var insertStatements []insertStatement + for _, i := range *p.insertStatements { + is, err := parseInsertStatementStr(i) + if err != nil { + return planConfig{}, fmt.Errorf("parsing insert statement from %q: %w", i, err) + } + insertStatements = append(insertStatements, is) + } + + return planConfig{ + schemaDir: *p.schemaDir, + statementTimeoutModifiers: statementTimeoutModifiers, + insertStatements: insertStatements, + }, nil +} + +func parseStatementTimeoutModifierStr(val string) (statementTimeoutModifier, error) { + submatches := statementTimeoutModifierRegex.FindStringSubmatch(val) + if len(submatches) <= regexSTMRegexIndex || len(submatches) <= durationSTMRegexIndex { + return statementTimeoutModifier{}, fmt.Errorf("could not parse regex and duration from arg. expected to be in the format of " + + "'Some.*Regex='. Example durations include: 2s, 5m, 10.5h") + } + regexStr := submatches[regexSTMRegexIndex] + durationStr := submatches[durationSTMRegexIndex] + + regex, err := regexp.Compile(regexStr) + if err != nil { + return statementTimeoutModifier{}, fmt.Errorf("regex could not be compiled from %q: %w", regexStr, err) + } + + duration, err := time.ParseDuration(durationStr) + if err != nil { + return statementTimeoutModifier{}, fmt.Errorf("duration could not be parsed from %q: %w", durationStr, err) + } + + return statementTimeoutModifier{ + regex: regex, + timeout: duration, + }, nil +} + +func parseInsertStatementStr(val string) (insertStatement, error) { + submatches := insertStatementRegex.FindStringSubmatch(val) + if len(submatches) <= indexInsertStatementRegexIndex || + len(submatches) <= durationInsertStatementRegexIndex || + len(submatches) <= ddlInsertStatementRegexIndex { + return insertStatement{}, fmt.Errorf("could not parse index, duration, and statement from arg. expected to be in the " + + "format of ' :'. Example durations include: 2s, 5m, 10.5h") + } + indexStr := submatches[indexInsertStatementRegexIndex] + index, err := strconv.Atoi(indexStr) + if err != nil { + return insertStatement{}, fmt.Errorf("could not parse index (an int) from \"%q\"", indexStr) + } + + durationStr := submatches[durationInsertStatementRegexIndex] + duration, err := time.ParseDuration(durationStr) + if err != nil { + return insertStatement{}, fmt.Errorf("duration could not be parsed from \"%q\": %w", durationStr, err) + } + + return insertStatement{ + index: index, + ddl: submatches[ddlInsertStatementRegexIndex], + timeout: duration, + }, nil +} + +func generatePlan(ctx context.Context, logger log.Logger, connConfig *pgx.ConnConfig, planConfig planConfig) (diff.Plan, error) { + ddl, err := getDDLFromPath(planConfig.schemaDir) + if err != nil { + return diff.Plan{}, nil + } + + tempDbFactory, err := tempdb.NewOnInstanceFactory(ctx, func(ctx context.Context, dbName string) (*sql.DB, error) { + copiedConfig := connConfig.Copy() + copiedConfig.Database = dbName + return openDbWithPgxConfig(copiedConfig) + }) + if err != nil { + return diff.Plan{}, err + } + defer func() { + err := tempDbFactory.Close() + if err != nil { + logger.Errorf("error shutting down temp db factory: %v", err) + } + }() + + connPool, err := openDbWithPgxConfig(connConfig) + if err != nil { + return diff.Plan{}, err + } + defer connPool.Close() + + conn, err := connPool.Conn(ctx) + if err != nil { + return diff.Plan{}, err + } + defer conn.Close() + + plan, err := diff.GeneratePlan(ctx, conn, tempDbFactory, ddl, + diff.WithDataPackNewTables(), + diff.WithIgnoreChangesToColOrder(), + ) + if err != nil { + return diff.Plan{}, fmt.Errorf("generating plan: %w", err) + } + + modifiedPlan, err := applyPlanModifiers(plan, planConfig.statementTimeoutModifiers, planConfig.insertStatements) + if err != nil { + return diff.Plan{}, fmt.Errorf("applying plan modifiers: %w", err) + } + + return modifiedPlan, nil +} + +func applyPlanModifiers( + plan diff.Plan, + statementTimeoutModifiers []statementTimeoutModifier, + insertStatements []insertStatement, +) (diff.Plan, error) { + for _, stm := range statementTimeoutModifiers { + plan = plan.ApplyStatementTimeoutModifier(stm.regex, stm.timeout) + } + for _, is := range insertStatements { + var err error + plan, err = plan.InsertStatement(is.index, diff.Statement{ + DDL: is.ddl, + Timeout: is.timeout, + Hazards: []diff.MigrationHazard{{ + Type: diff.MigrationHazardTypeIsUserGenerated, + Message: "This statement is user-generated", + }}, + }) + if err != nil { + return diff.Plan{}, fmt.Errorf("inserting statement %q with timeout %s at index %d: %w", + is.ddl, is.timeout, is.index, err) + } + } + return plan, nil +} + +func getDDLFromPath(path string) ([]string, error) { + fileEntries, err := os.ReadDir(path) + if err != nil { + return nil, err + } + var ddl []string + for _, entry := range fileEntries { + if filepath.Ext(entry.Name()) == ".sql" { + if stmts, err := os.ReadFile(filepath.Join(path, entry.Name())); err != nil { + return nil, err + } else { + ddl = append(ddl, string(stmts)) + } + } + } + return ddl, nil +} + +func planToPrettyS(plan diff.Plan) string { + sb := strings.Builder{} + + // We are going to put a statement index before each statement. To do that, + // we need to find how many characters are in the largest index, so we can provide the appropriate amount + // of padding before the statements to align all of them + // E.g. + // 1. ALTER TABLE foobar ADD COLUMN foo BIGINT + // .... + // 22. ADD INDEX some_idx ON some_other_table(some_column) + stmtNumPadding := len(strconv.Itoa(len(plan.Statements))) // find how much padding is required for the statement index + fmtString := fmt.Sprintf("%%0%dd. %%s", stmtNumPadding) // supply custom padding + + var stmtStrs []string + for i, stmt := range plan.Statements { + stmtStr := fmt.Sprintf(fmtString, getDisplayableStmtIdx(i), statementToPrettyS(stmt)) + stmtStrs = append(stmtStrs, stmtStr) + } + sb.WriteString(strings.Join(stmtStrs, "\n\n")) + + return sb.String() +} + +func statementToPrettyS(stmt diff.Statement) string { + sb := strings.Builder{} + sb.WriteString(fmt.Sprintf("%s;", stmt.DDL)) + sb.WriteString(fmt.Sprintf("\n\t-- Timeout: %s", stmt.Timeout)) + if len(stmt.Hazards) > 0 { + for _, hazard := range stmt.Hazards { + sb.WriteString(fmt.Sprintf("\n\t-- Hazard %s", hazardToPrettyS(hazard))) + } + } + return sb.String() +} + +func hazardToPrettyS(hazard diff.MigrationHazard) string { + if len(hazard.Message) > 0 { + return fmt.Sprintf("%s: %s", hazard.Type, hazard.Message) + } else { + return hazard.Type + } +} diff --git a/cmd/pg-schema-diff/plan_cmd_test.go b/cmd/pg-schema-diff/plan_cmd_test.go new file mode 100644 index 0000000..3908bfe --- /dev/null +++ b/cmd/pg-schema-diff/plan_cmd_test.go @@ -0,0 +1,119 @@ +package main + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseStatementTimeoutModifierStr(t *testing.T) { + for _, tc := range []struct { + opt string `explicit:"always"` + + expectedRegexStr string + expectedTimeout time.Duration + expectedErrContains string + }{ + { + opt: "normal duration=5m", + expectedRegexStr: "normal duration", + expectedTimeout: 5 * time.Minute, + }, + { + opt: "some regex with a duration ending in a period=5.h", + expectedRegexStr: "some regex with a duration ending in a period", + expectedTimeout: 5 * time.Hour, + }, + { + opt: " starts with spaces than has a *=5.5m", + expectedRegexStr: " starts with spaces than has a *", + expectedTimeout: time.Minute*5 + 30*time.Second, + }, + { + opt: "has a valid opt in the regex something=5.5m in the regex =15s", + expectedRegexStr: "has a valid opt in the regex something=5.5m in the regex ", + expectedTimeout: 15 * time.Second, + }, + { + opt: "has multiple valid opts opt=15m5s in the regex something=5.5m in the regex and has compound duration=15m1ms2us10ns", + expectedRegexStr: "has multiple valid opts opt=15m5s in the regex something=5.5m in the regex and has compound duration", + expectedTimeout: 15*time.Minute + 1*time.Millisecond + 2*time.Microsecond + 10*time.Nanosecond, + }, + { + opt: "=5m", + expectedErrContains: "could not parse regex and duration from arg", + }, + { + opt: "15m", + expectedErrContains: "could not parse regex and duration from arg", + }, + { + opt: "someregex;15m", + expectedErrContains: "could not parse regex and duration from arg", + }, + { + opt: "someregex=invalid duration5s", + expectedErrContains: "duration could not be parsed", + }, + } { + t.Run(tc.opt, func(t *testing.T) { + modifier, err := parseStatementTimeoutModifierStr(tc.opt) + if len(tc.expectedErrContains) > 0 { + assert.ErrorContains(t, err, tc.expectedErrContains) + return + } + require.NoError(t, err) + assert.Equal(t, tc.expectedRegexStr, modifier.regex.String()) + assert.Equal(t, tc.expectedTimeout, modifier.timeout) + }) + } +} + +func TestParseInsertStatementStr(t *testing.T) { + for _, tc := range []struct { + opt string `explicit:"always"` + expectedInsertStmt insertStatement + expectedErrContains string + }{ + { + opt: "1 0h5.1m:SELECT * FROM :TABLE:0_5m:something", + expectedInsertStmt: insertStatement{ + index: 1, + ddl: "SELECT * FROM :TABLE:0_5m:something", + timeout: 5*time.Minute + 6*time.Second, + }, + }, + { + opt: "0 100ms:SELECT 1; SELECT * FROM something;", + expectedInsertStmt: insertStatement{ + index: 0, + ddl: "SELECT 1; SELECT * FROM something", + timeout: 100 * time.Millisecond, + }, + }, + { + opt: " 5s:No index", + expectedErrContains: "could not parse", + }, + { + opt: "0 5g:Invalid duration", + expectedErrContains: "duration could not be parsed", + }, + { + opt: "0 5s:", + expectedErrContains: "could not parse", + }, + } { + t.Run(tc.opt, func(t *testing.T) { + insertStatement, err := parseInsertStatementStr(tc.opt) + if len(tc.expectedErrContains) > 0 { + assert.ErrorContains(t, err, tc.expectedErrContains) + return + } + require.NoError(t, err) + assert.Equal(t, tc.expectedInsertStmt, insertStatement) + }) + } +} diff --git a/cmd/pg-schema-diff/sql.go b/cmd/pg-schema-diff/sql.go new file mode 100644 index 0000000..7d13533 --- /dev/null +++ b/cmd/pg-schema-diff/sql.go @@ -0,0 +1,18 @@ +package main + +import ( + "database/sql" + + "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/stdlib" +) + +// openDbWithPgxConfig opens a database connection using the provided pgx.ConnConfig and pings it +func openDbWithPgxConfig(config *pgx.ConnConfig) (*sql.DB, error) { + connPool := stdlib.OpenDB(*config) + if err := connPool.Ping(); err != nil { + connPool.Close() + return nil, err + } + return connPool, nil +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..373915d --- /dev/null +++ b/go.mod @@ -0,0 +1,35 @@ +module github.com/stripe/pg-schema-diff + +go 1.18 + +require ( + github.com/google/go-cmp v0.5.9 + github.com/google/uuid v1.3.0 + github.com/jackc/pgx/v4 v4.14.0 + github.com/kr/pretty v0.3.1 + github.com/mitchellh/hashstructure/v2 v2.0.2 + github.com/stretchr/testify v1.8.2 +) + +require ( + github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/jackc/chunkreader/v2 v2.0.1 // indirect + github.com/jackc/pgconn v1.14.0 // indirect + github.com/jackc/pgio v1.0.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgproto3/v2 v2.3.2 // indirect + github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect + github.com/jackc/pgtype v1.14.0 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/manifoldco/promptui v0.9.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/rogpeppe/go-internal v1.9.0 // indirect + github.com/spf13/cobra v1.7.0 // indirect + github.com/spf13/pflag v1.0.5 // indirect + golang.org/x/crypto v0.6.0 // indirect + golang.org/x/sys v0.5.0 // indirect + golang.org/x/text v0.7.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..635fea5 --- /dev/null +++ b/go.sum @@ -0,0 +1,234 @@ +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/Masterminds/semver/v3 v3.1.1 h1:hLg3sBzpNErnxhQtUy/mmLR2I9foDujNK030IGemrRc= +github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs= +github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= +github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e h1:fY5BOSpyZCqRo5OhCuC+XN+r/bBCmeuuJtjz+bCNIf8= +github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= +github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= +github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= +github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= +github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= +github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= +github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw= +github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= +github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= +github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8= +github.com/jackc/chunkreader/v2 v2.0.1/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= +github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA= +github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE= +github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s= +github.com/jackc/pgconn v1.8.0/go.mod h1:1C2Pb36bGIP9QHGBYCjnyhqu7Rv3sGshaQUvmfGIB/o= +github.com/jackc/pgconn v1.9.0/go.mod h1:YctiPyvzfU11JFxoXokUOOKQXQmDMoJL9vJzHH8/2JY= +github.com/jackc/pgconn v1.9.1-0.20210724152538-d89c8390a530/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI= +github.com/jackc/pgconn v1.10.1/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI= +github.com/jackc/pgconn v1.14.0 h1:vrbA9Ud87g6JdFWkHTJXppVce58qPIdP7N8y0Ml/A7Q= +github.com/jackc/pgconn v1.14.0/go.mod h1:9mBNlny0UvkgJdCDvdVHYSjI+8tD2rnKK69Wz8ti++E= +github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= +github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= +github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= +github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd/go.mod h1:hrBW0Enj2AZTNpt/7Y5rr2xe/9Mn757Wtb2xeBzPv2c= +github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 h1:DadwsjnMwFjfWc9y5Wi/+Zz7xoE5ALHsRQlOctkOiHc= +github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65/go.mod h1:5R2h2EEX+qri8jOWMbJCtaPWkrrNc7OHwsp2TCqp7ak= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= +github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= +github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= +github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.2.0/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.3.2 h1:7eY55bdBeCz1F2fTzSz69QC+pG46jYq9/jtSPiJ5nn0= +github.com/jackc/pgproto3/v2 v2.3.2/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= +github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= +github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= +github.com/jackc/pgtype v1.8.1-0.20210724151600-32e20a603178/go.mod h1:C516IlIV9NKqfsMCXTdChteoXmwgUceqaLfjg2e3NlM= +github.com/jackc/pgtype v1.9.0/go.mod h1:LUMuVrfsFfdKGLw+AFFVv6KtHOFMwRgDDzBt76IqCA4= +github.com/jackc/pgtype v1.14.0 h1:y+xUdabmyMkJLyApYuPj38mW+aAIqCe5uuBB51rH3Vw= +github.com/jackc/pgtype v1.14.0/go.mod h1:LUMuVrfsFfdKGLw+AFFVv6KtHOFMwRgDDzBt76IqCA4= +github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= +github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM= +github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= +github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c/go.mod h1:1QD0+tgSXP7iUjYm9C1NxKhny7lq6ee99u/z+IHFcgs= +github.com/jackc/pgx/v4 v4.14.0 h1:TgdrmgnM7VY72EuSQzBbBd4JA1RLqJolrw9nQVZABVc= +github.com/jackc/pgx/v4 v4.14.0/go.mod h1:jT3ibf/A0ZVCp89rtCIN0zCJxcE74ypROmHEZYsG/j8= +github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v1.2.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.10.2 h1:AqzbZs4ZoCBp+GtejcpCpcxM3zlSMx29dXbUSeVtJb8= +github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/manifoldco/promptui v0.9.0 h1:3V4HzJk1TtXW1MTZMP7mdlwbBpIinw3HztaIlYthEiA= +github.com/manifoldco/promptui v0.9.0/go.mod h1:ka04sppxSGFAtxX0qhlYQjISsg9mR4GWtQEhdbn6Pgg= +github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= +github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= +github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/mitchellh/hashstructure/v2 v2.0.2 h1:vGKWl0YJqUNxE8d+h8f6NJLcCJrgbhC4NcD46KavDd4= +github.com/mitchellh/hashstructure/v2 v2.0.2/go.mod h1:MG3aRVU/N29oo/V/IhBX8GR/zz4kQkprJgF2EVszyDE= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= +github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= +github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= +github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= +github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= +github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= +github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= +github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= +github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= +github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= +github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= +github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= +go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= +go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= +go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= +go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= +go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= +go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA= +go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= +go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= +go.uber.org/zap v1.13.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= +golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= +golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.6.0 h1:qfktjS5LUO+fFKeJXZ+ikTRijMmljikvG68fpMMruSc= +golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= +golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= +golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181122145206-62eef0e2fa9b/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo= +golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20200103221440-774c71fcf114/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= diff --git a/internal/graph/graph.go b/internal/graph/graph.go new file mode 100644 index 0000000..4d26620 --- /dev/null +++ b/internal/graph/graph.go @@ -0,0 +1,215 @@ +package graph + +import ( + "fmt" + "sort" +) + +type Vertex interface { + GetId() string +} + +type AdjacencyMatrix map[string]map[string]bool + +// Graph is a directed graph +type Graph[V Vertex] struct { + verticesById map[string]V + edges AdjacencyMatrix +} + +func NewGraph[V Vertex]() *Graph[V] { + return &Graph[V]{ + verticesById: make(map[string]V), + edges: make(AdjacencyMatrix), + } +} + +// AddVertex adds a vertex to the graph. +// If the vertex already exists, it will override it and keep the edges +func (g *Graph[V]) AddVertex(v V) { + g.verticesById[v.GetId()] = v + if g.edges[v.GetId()] == nil { + g.edges[v.GetId()] = make(map[string]bool) + } +} + +// AddEdge adds an edge to the graph. If the vertex doesn't exist, it will error +func (g *Graph[V]) AddEdge(sourceId, targetId string) error { + if !g.HasVertexWithId(sourceId) { + return fmt.Errorf("source %s does not exist", sourceId) + } + if !g.HasVertexWithId(targetId) { + return fmt.Errorf("target %s does not exist", targetId) + } + g.edges[sourceId][targetId] = true + + return nil +} + +// Union unions the graph with a new graph. If a vertex exists in both graphs, +// it uses the merge function to determine what the new vertex is +func (g *Graph[V]) Union(new *Graph[V], merge func(old, new V) V) error { + for _, newV := range new.verticesById { + if g.HasVertexWithId(newV.GetId()) { + id := newV.GetId() + // merge the vertices using the procedure defined by the user + newV = merge(g.GetVertex(newV.GetId()), newV) + if newV.GetId() != id { + return fmt.Errorf("the merge function must return a vertex with the same id: "+ + "expected %s but found %s", id, newV.GetId()) + } + } + g.AddVertex(newV) + } + + for source, adjacentEdgesMap := range new.edges { + for target, isAdjacent := range adjacentEdgesMap { + if isAdjacent { + if err := g.AddEdge(source, target); err != nil { + return fmt.Errorf("adding an edge from the new graph: %w", err) + } + } + } + } + + return nil +} + +func (g *Graph[V]) GetVertex(id string) V { + return g.verticesById[id] +} + +func (g *Graph[V]) HasVertexWithId(id string) bool { + _, hasVertex := g.verticesById[id] + return hasVertex +} + +// Reverse reverses the edges of the map. The sources become the sinks and vice versa. +func (g *Graph[V]) Reverse() { + reversedEdges := make(AdjacencyMatrix) + for vertexId, _ := range g.verticesById { + reversedEdges[vertexId] = make(map[string]bool) + } + for source, adjacentEdgesMap := range g.edges { + for target, isAdjacent := range adjacentEdgesMap { + if isAdjacent { + reversedEdges[target][source] = true + } + } + } + g.edges = reversedEdges +} + +func (g *Graph[V]) Copy() *Graph[V] { + verticesById := make(map[string]V) + for id, v := range g.verticesById { + verticesById[id] = v + } + + edges := make(AdjacencyMatrix) + for source, adjacentEdgesMap := range g.edges { + edges[source] = make(map[string]bool) + for target, isAdjacent := range adjacentEdgesMap { + edges[source][target] = isAdjacent + } + } + + return &Graph[V]{ + verticesById: verticesById, + edges: edges, + } +} + +func (g *Graph[V]) TopologicallySort() ([]V, error) { + return g.TopologicallySortWithPriority(func(_, _ V) bool { + return false + }) +} + +type Ordered interface { + ~int | ~string +} + +func IsLowerPriorityFromGetPriority[V Vertex, P Ordered](getPriority func(V) P) func(V, V) bool { + return func(v1 V, v2 V) bool { + return getPriority(v1) < getPriority(v2) + } +} + +// TopologicallySortWithPriority returns a consistent topological sort of the graph taking a greedy approach to put +// high priority sources first. The output is deterministic. getPriority must be deterministic +func (g *Graph[V]) TopologicallySortWithPriority(isLowerPriority func(V, V) bool) ([]V, error) { + // This uses mutation. Copy the graph + graph := g.Copy() + + // The strategy of this algorithm: + // 1. Count the number of incoming edges to each vertex + // 2. Remove the sources and add them to the outputIds + // 3. Decrement the number of incoming edges to each vertex adjacent to the sink + // 4. Repeat 2-3 until the graph is empty + + // The number of outgoing in the reversed graph is the number of incoming in the original graph + // In other words, a source in the graph (has no incoming edges) is a sink in the reversed graph + // (has no outgoing edges) + + // To find the number of incoming edges in graph, just count the number of outgoing edges + // in the reversed graph + reversedGraph := graph.Copy() + reversedGraph.Reverse() + incomingEdgeCountByVertex := make(map[string]int) + for vertex, reversedAdjacentEdges := range reversedGraph.edges { + count := 0 + for _, isAdjacent := range reversedAdjacentEdges { + if isAdjacent { + count++ + } + } + incomingEdgeCountByVertex[vertex] = count + } + + var output []V + // Add the sinks to the output. Delete the sinks. Repeat. + for len(graph.verticesById) > 0 { + // Put all the sources into an array, so we can get a stable sort of them before identifying the one + // with the highest priority + var sources []V + for sourceId, incomingEdgeCount := range incomingEdgeCountByVertex { + if incomingEdgeCount == 0 { + sources = append(sources, g.GetVertex(sourceId)) + } + } + sort.Slice(sources, func(i, j int) bool { + return sources[i].GetId() < sources[j].GetId() + }) + + // Take the source with highest priority from the sorted array of sources + indexOfSourceWithHighestPri := -1 + for i, source := range sources { + if indexOfSourceWithHighestPri == -1 || isLowerPriority(sources[indexOfSourceWithHighestPri], source) { + indexOfSourceWithHighestPri = i + } + } + if indexOfSourceWithHighestPri == -1 { + return nil, fmt.Errorf("cycle detected: %+v, %+v", graph, incomingEdgeCountByVertex) + } + sourceWithHighestPriority := sources[indexOfSourceWithHighestPri] + + output = append(output, sourceWithHighestPriority) + + // Remove source vertex from graph and update counts + // Update incoming edge counts + for target, hasEdge := range graph.edges[sourceWithHighestPriority.GetId()] { + if hasEdge { + incomingEdgeCountByVertex[target]-- + } + } + + // Delete the vertex from graph + delete(graph.verticesById, sourceWithHighestPriority.GetId()) + // We don't need to worry about any incoming edges referencing this vertex, since it was a sink + delete(graph.edges, sourceWithHighestPriority.GetId()) + delete(incomingEdgeCountByVertex, sourceWithHighestPriority.GetId()) + } + + return output, nil +} diff --git a/internal/graph/graph_test.go b/internal/graph/graph_test.go new file mode 100644 index 0000000..de0a010 --- /dev/null +++ b/internal/graph/graph_test.go @@ -0,0 +1,432 @@ +package graph + +import ( + "fmt" + "strconv" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAddGetHasOperations(t *testing.T) { + g := NewGraph[vertex]() + + // get missing vertex return Zero value + assert.Zero(t, g.GetVertex("missing_vertex")) + + // add vertex + v1 := NewV("v_1") + g.AddVertex(v1) + assert.True(t, g.HasVertexWithId("v_1")) + assert.Equal(t, v1, g.GetVertex("v_1")) + + // override vertex + newV1 := NewV("v_1") + g.AddVertex(newV1) + assert.True(t, g.HasVertexWithId("v_1")) + assert.Equal(t, newV1, g.GetVertex("v_1")) + assert.NotEqual(t, v1, g.GetVertex("v_1")) + + // add edge when target is missing + assert.Error(t, g.AddEdge("v_1", "missing_vertex")) + + // add edge when source is missing + assert.Error(t, g.AddEdge("missing_vertex", "v_1")) + + // add edge when both are present + v2 := NewV("v_2") + g.AddVertex(v2) + assert.True(t, g.HasVertexWithId("v_2")) + assert.NoError(t, g.AddEdge("v_1", "v_2")) + assert.Equal(t, AdjacencyMatrix{ + "v_1": { + "v_2": true, + }, + "v_2": {}, + }, g.edges) + + // overriding vertex keeps edges + g.AddVertex(NewV("v_1")) + g.AddVertex(NewV("v_2")) + assert.Equal(t, AdjacencyMatrix{ + "v_1": { + "v_2": true, + }, + "v_2": {}, + }, g.edges) + + // override edge + assert.NoError(t, g.AddEdge("v_1", "v_2")) + assert.Equal(t, AdjacencyMatrix{ + "v_1": { + "v_2": true, + }, + "v_2": {}, + }, g.edges) + + // allow cycles + assert.NoError(t, g.AddEdge("v_2", "v_1")) + assert.Equal(t, AdjacencyMatrix{ + "v_1": { + "v_2": true, + }, + "v_2": { + "v_1": true, + }, + }, g.edges) +} + +func TestReverse(t *testing.T) { + g := NewGraph[vertex]() + g.AddVertex(NewV("v_1")) + g.AddVertex(NewV("v_2")) + g.AddVertex(NewV("v_3")) + assert.NoError(t, g.AddEdge("v_1", "v_2")) + assert.NoError(t, g.AddEdge("v_1", "v_3")) + assert.NoError(t, g.AddEdge("v_2", "v_3")) + g.Reverse() + assert.Equal(t, AdjacencyMatrix{ + "v_1": {}, + "v_2": { + "v_1": true, + }, + "v_3": { + "v_1": true, + "v_2": true, + }, + }, g.edges) + assert.ElementsMatch(t, getVertexIds(g), []string{"v_1", "v_2", "v_3"}) +} + +func TestCopy(t *testing.T) { + g := NewGraph[vertex]() + g.AddVertex(NewV("shared_1")) + g.AddVertex(NewV("shared_2")) + g.AddVertex(NewV("shared_3")) + assert.NoError(t, g.AddEdge("shared_1", "shared_3")) + assert.NoError(t, g.AddEdge("shared_3", "shared_2")) + + gC := g.Copy() + gC.AddVertex(NewV("g_copy_1")) + gC.AddVertex(NewV("g_copy_2")) + assert.NoError(t, gC.AddEdge("g_copy_1", "g_copy_2")) + assert.NoError(t, gC.AddEdge("shared_3", "g_copy_1")) + assert.NoError(t, gC.AddEdge("shared_3", "shared_1")) + copyOverrideShared1 := NewV("shared_1") + gC.AddVertex(copyOverrideShared1) + + g.AddVertex(NewV("g_1")) + g.AddVertex(NewV("g_2")) + g.AddVertex(NewV("g_3")) + assert.NoError(t, g.AddEdge("g_3", "g_1")) + assert.NoError(t, g.AddEdge("g_2", "shared_2")) + assert.NoError(t, g.AddEdge("shared_2", "shared_3")) + originalOverrideShared2 := NewV("shared_2") + g.AddVertex(originalOverrideShared2) + + // validate nodes on copy + assert.ElementsMatch(t, getVertexIds(gC), []string{ + "shared_1", "shared_2", "shared_3", "g_copy_1", "g_copy_2", + }) + + // validate nodes on original + assert.ElementsMatch(t, getVertexIds(g), []string{ + "shared_1", "shared_2", "shared_3", "g_1", "g_2", "g_3", + }) + + // validate overrides weren't copied over and non-overriden shared nodes are the same + assert.NotEqual(t, g.GetVertex("shared_1"), gC.GetVertex("shared_1")) + assert.Equal(t, gC.GetVertex("shared_1"), copyOverrideShared1) + assert.NotEqual(t, g.GetVertex("shared_2"), gC.GetVertex("shared_2")) + assert.Equal(t, g.GetVertex("shared_2"), originalOverrideShared2) + assert.Equal(t, g.GetVertex("shared_3"), gC.GetVertex("shared_3")) + + // validate edges + assert.Equal(t, AdjacencyMatrix{ + "shared_1": { + "shared_3": true, + }, + "shared_2": {}, + "shared_3": { + "shared_1": true, + "shared_2": true, + "g_copy_1": true, + }, + "g_copy_1": { + "g_copy_2": true, + }, + "g_copy_2": {}, + }, gC.edges) + assert.Equal(t, AdjacencyMatrix{ + "shared_1": { + "shared_3": true, + }, + "shared_2": { + "shared_3": true, + }, + "shared_3": { + "shared_2": true, + }, + "g_1": {}, + "g_2": { + "shared_2": true, + }, + "g_3": { + "g_1": true, + }, + }, g.edges) +} + +func TestUnion(t *testing.T) { + gA := NewGraph[vertex]() + gA1 := NewV("a_1") + gA.AddVertex(gA1) + gA2 := NewV("a_2") + gA.AddVertex(gA2) + gA3 := NewV("a_3") + gA.AddVertex(gA3) + gAShared1 := NewV("shared_1") + gA.AddVertex(gAShared1) + gAShared2 := NewV("shared_2") + gA.AddVertex(gAShared2) + assert.NoError(t, gA.AddEdge("a_1", "a_2")) + assert.NoError(t, gA.AddEdge("a_3", "a_1")) + assert.NoError(t, gA.AddEdge("shared_1", "a_1")) + + gB := NewGraph[vertex]() + gB1 := NewV("b_1") + gB.AddVertex(gB1) + gB2 := NewV("b_2") + gB.AddVertex(gB2) + gB3 := NewV("b_3") + gB.AddVertex(gB3) + gBShared1 := NewV("shared_1") + gB.AddVertex(gBShared1) + gBShared2 := NewV("shared_2") + gB.AddVertex(gBShared2) + assert.NoError(t, gB.AddEdge("b_3", "b_2")) + assert.NoError(t, gB.AddEdge("b_3", "b_1")) + assert.NoError(t, gB.AddEdge("shared_1", "b_2")) + + err := gA.Union(gB, func(old, new vertex) vertex { + return vertex{ + id: old.id, + val: fmt.Sprintf("%s_%s", old.val, new.val), + } + }) + assert.NoError(t, err) + + // make sure non-shared nodes were not merged + assert.Equal(t, gA1, gA.GetVertex("a_1")) + assert.Equal(t, gA2, gA.GetVertex("a_2")) + assert.Equal(t, gA3, gA.GetVertex("a_3")) + assert.Equal(t, gB1, gA.GetVertex("b_1")) + assert.Equal(t, gB2, gA.GetVertex("b_2")) + assert.Equal(t, gB3, gA.GetVertex("b_3")) + + // check merged nodes + assert.Equal(t, vertex{ + id: "shared_1", + val: fmt.Sprintf("%s_%s", gAShared1.val, gBShared1.val), + }, gA.GetVertex("shared_1")) + assert.Equal(t, vertex{ + id: "shared_2", + val: fmt.Sprintf("%s_%s", gAShared2.val, gBShared2.val), + }, gA.GetVertex("shared_2")) + + // no extra nodes + assert.ElementsMatch(t, getVertexIds(gA), []string{ + "a_1", "a_2", "a_3", "b_1", "b_2", "b_3", "shared_1", "shared_2", + }) + + assert.Equal(t, AdjacencyMatrix{ + "a_1": { + "a_2": true, + }, + "a_2": {}, + "a_3": { + "a_1": true, + }, + "b_1": {}, + "b_2": {}, + "b_3": { + "b_1": true, + "b_2": true, + }, + "shared_1": { + "a_1": true, + "b_2": true, + }, + "shared_2": {}, + }, gA.edges) +} + +func TestUnionFailsIfMergeReturnsDifferentId(t *testing.T) { + gA := NewGraph[vertex]() + gA.AddVertex(NewV("shared_1")) + gB := NewGraph[vertex]() + gB.AddVertex(NewV("shared_1")) + assert.Error(t, gA.Union(gB, func(old, new vertex) vertex { + return vertex{ + id: "different_id", + val: old.val, + } + })) +} + +func TestTopologicallySort(t *testing.T) { + // Source: https://en.wikipedia.org/wiki/Topological_sorting#Examples + g := NewGraph[vertex]() + v5 := NewV("05") + g.AddVertex(v5) + v7 := NewV("07") + g.AddVertex(v7) + v3 := NewV("03") + g.AddVertex(v3) + v11 := NewV("11") + g.AddVertex(v11) + v8 := NewV("08") + g.AddVertex(v8) + v2 := NewV("02") + g.AddVertex(v2) + v9 := NewV("09") + g.AddVertex(v9) + v10 := NewV("10") + g.AddVertex(v10) + assert.NoError(t, g.AddEdge("05", "11")) + assert.NoError(t, g.AddEdge("07", "11")) + assert.NoError(t, g.AddEdge("07", "08")) + assert.NoError(t, g.AddEdge("03", "08")) + assert.NoError(t, g.AddEdge("03", "10")) + assert.NoError(t, g.AddEdge("11", "02")) + assert.NoError(t, g.AddEdge("11", "09")) + assert.NoError(t, g.AddEdge("11", "10")) + assert.NoError(t, g.AddEdge("08", "09")) + + orderedNodes, err := g.TopologicallySort() + assert.NoError(t, err) + assert.Equal(t, []vertex{ + v3, v5, v7, v8, v11, v2, v9, v10, + }, orderedNodes) + + // Cycle should error + assert.NoError(t, g.AddEdge("10", "07")) + _, err = g.TopologicallySort() + assert.Error(t, err) +} + +func TestTopologicallySortWithPriority(t *testing.T) { + // Source: https://en.wikipedia.org/wiki/Topological_sorting#Examples + g := NewGraph[vertex]() + v5 := NewV("05") + g.AddVertex(v5) + v7 := NewV("07") + g.AddVertex(v7) + v3 := NewV("03") + g.AddVertex(v3) + v11 := NewV("11") + g.AddVertex(v11) + v8 := NewV("08") + g.AddVertex(v8) + v2 := NewV("02") + g.AddVertex(v2) + v9 := NewV("09") + g.AddVertex(v9) + v10 := NewV("10") + g.AddVertex(v10) + assert.NoError(t, g.AddEdge("05", "11")) + assert.NoError(t, g.AddEdge("07", "11")) + assert.NoError(t, g.AddEdge("07", "08")) + assert.NoError(t, g.AddEdge("03", "08")) + assert.NoError(t, g.AddEdge("03", "10")) + assert.NoError(t, g.AddEdge("11", "02")) + assert.NoError(t, g.AddEdge("11", "09")) + assert.NoError(t, g.AddEdge("11", "10")) + assert.NoError(t, g.AddEdge("08", "09")) + + for _, tc := range []struct { + name string + isLowerPriority func(v1, v2 vertex) bool + expectedOrdering []vertex + }{ + { + name: "largest-numbered available vertex first (string-based GetPriority)", + isLowerPriority: IsLowerPriorityFromGetPriority(func(v vertex) string { + return v.GetId() + }), + expectedOrdering: []vertex{v7, v5, v11, v3, v10, v8, v9, v2}, + }, + { + name: "smallest-numbered available vertex first (numeric-based GetPriority)", + isLowerPriority: IsLowerPriorityFromGetPriority(func(v vertex) int { + idAsInt, err := strconv.Atoi(v.GetId()) + require.NoError(t, err) + return -idAsInt + }), + expectedOrdering: []vertex{v3, v5, v7, v8, v11, v2, v9, v10}, + }, + { + name: "fewest edges first (prioritize high id's for tie breakers)", + isLowerPriority: func(v1, v2 vertex) bool { + v1EdgeCount := getEdgeCount(g, v1) + v2EdgeCount := getEdgeCount(g, v2) + if v1EdgeCount == v2EdgeCount { + // Break ties with ID + return v1.GetId() < v2.GetId() + } + return v1EdgeCount > v2EdgeCount + }, + expectedOrdering: []vertex{v5, v7, v3, v8, v11, v10, v9, v2}, + }, + } { + t.Run(tc.name, func(t *testing.T) { + + // Highest number vertices should be prioritized first + orderedNodes, err := g.TopologicallySortWithPriority(tc.isLowerPriority) + assert.NoError(t, err) + assert.Equal(t, tc.expectedOrdering, orderedNodes) + }) + } + + // Cycle should error + assert.NoError(t, g.AddEdge("10", "07")) + _, err := g.TopologicallySort() + assert.Error(t, err) +} + +func getEdgeCount[V Vertex](g *Graph[V], v Vertex) int { + edgeCount := 0 + for _, hasEdge := range g.edges[v.GetId()] { + if hasEdge { + edgeCount++ + } + } + return edgeCount +} + +type vertex struct { + id string + val string +} + +func NewV(id string) vertex { + uuid, err := uuid.NewUUID() + if err != nil { + panic(err) + } + return vertex{id: id, val: uuid.String()} +} + +func (v vertex) GetId() string { + return v.id +} + +func getVertexIds(g *Graph[vertex]) []string { + var output []string + for id, _ := range g.verticesById { + output = append(output, id) + } + return output +} diff --git a/internal/migration_acceptance_tests/acceptance_test.go b/internal/migration_acceptance_tests/acceptance_test.go new file mode 100644 index 0000000..940dab7 --- /dev/null +++ b/internal/migration_acceptance_tests/acceptance_test.go @@ -0,0 +1,206 @@ +package migration_acceptance_tests + +import ( + "context" + "database/sql" + "fmt" + "testing" + + _ "github.com/jackc/pgx/v4/stdlib" + "github.com/kr/pretty" + "github.com/stretchr/testify/suite" + "github.com/stripe/pg-schema-diff/internal/pgdump" + "github.com/stripe/pg-schema-diff/internal/pgengine" + "github.com/stripe/pg-schema-diff/pkg/diff" + + "github.com/stripe/pg-schema-diff/pkg/log" + "github.com/stripe/pg-schema-diff/pkg/tempdb" +) + +type ( + expectations struct { + planErrorIs error + planErrorContains string + // outputState should be the DDL required to reconstruct the expected output state of the database + // + // The outputState might differ from the newSchemaDDL due to options passed to the migrator. For example, + // the data packing option will cause the column ordering for new tables in the outputState to differ from + // the column ordering of those tables defined in newSchemaDDL + // + // If no outputState is specified, the newSchemaDDL will be used + outputState []string + } + + acceptanceTestCase struct { + name string + oldSchemaDDL []string + newSchemaDDL []string + + // expectedHazardTypes should contain all the unique migration hazard types that are expected to be within the + // generated plan + expectedHazardTypes []diff.MigrationHazardType + + // vanillaExpectations refers to the expectations of the migration if no additional opts are used + vanillaExpectations expectations + // dataPackingExpectations refers to the expectations of the migration if table packing and ignore column order are used + dataPackingExpectations expectations + } + + acceptanceTestSuite struct { + suite.Suite + pgEngine *pgengine.Engine + } +) + +func (suite *acceptanceTestSuite) SetupSuite() { + engine, err := pgengine.StartEngine() + suite.Require().NoError(err) + suite.pgEngine = engine +} + +func (suite *acceptanceTestSuite) TearDownSuite() { + suite.pgEngine.Close() +} + +// Simulates migrating a database and uses pgdump to compare the actual state to the expected state +func (suite *acceptanceTestSuite) runTestCases(acceptanceTestCases []acceptanceTestCase) { + for _, tc := range acceptanceTestCases { + suite.Run(tc.name, func() { + suite.Run("vanilla", func() { + suite.runSubtest(tc, tc.vanillaExpectations, nil) + }) + suite.Run("with data packing (and ignoring column order)", func() { + suite.runSubtest(tc, tc.dataPackingExpectations, []diff.PlanOpt{ + diff.WithDataPackNewTables(), + diff.WithIgnoreChangesToColOrder(), + diff.WithLogger(log.SimpleLogger()), + }) + }) + }) + } +} + +func (suite *acceptanceTestSuite) runSubtest(tc acceptanceTestCase, expects expectations, planOpts []diff.PlanOpt) { + // onDbInitQueries will be run on both the old database before the migration and the new database before pg_dump + onDbInitQueries := []string{ + // Enable an extension to enforce that diffing works with extensions enabled + `CREATE EXTENSION amcheck;`, + } + + // normalize the subtest + if expects.outputState == nil { + expects.outputState = tc.newSchemaDDL + } + + // Apply old schema DDL to old DB + oldDb, err := suite.pgEngine.CreateDatabase() + suite.Require().NoError(err) + defer oldDb.DropDB() + // Apply the old schema + suite.Require().NoError(applyDDL(oldDb, append(onDbInitQueries, tc.oldSchemaDDL...))) + + // Migrate the old DB + oldDBConnPool, err := sql.Open("pgx", oldDb.GetDSN()) + suite.Require().NoError(err) + defer oldDBConnPool.Close() + oldDbConn, _ := oldDBConnPool.Conn(context.Background()) + defer oldDbConn.Close() + + tempDbFactory, err := tempdb.NewOnInstanceFactory(context.Background(), func(ctx context.Context, dbName string) (*sql.DB, error) { + return sql.Open("pgx", suite.pgEngine.GetPostgresDatabaseConnOpts().With("dbname", dbName).ToDSN()) + }) + suite.Require().NoError(err) + defer func(tempDbFactory tempdb.Factory) { + // It's important that this closes properly (the temp database is dropped), + // so assert it has no error for acceptance tests + suite.Require().NoError(tempDbFactory.Close()) + }(tempDbFactory) + + plan, err := diff.GeneratePlan(context.Background(), oldDbConn, tempDbFactory, tc.newSchemaDDL, planOpts...) + + if expects.planErrorIs != nil || len(expects.planErrorContains) > 0 { + if expects.planErrorIs != nil { + suite.ErrorIs(err, expects.planErrorIs) + } + if len(expects.planErrorContains) > 0 { + suite.ErrorContains(err, expects.planErrorContains) + } + return + } + suite.Require().NoError(err) + + suite.ElementsMatch(tc.expectedHazardTypes, getUniqueHazardTypesFromStatements(plan.Statements), prettySprintPlan(plan)) + + // Apply the plan + suite.Require().NoError(applyPlan(oldDb, plan), prettySprintPlan(plan)) + + // Make sure the pgdump after running the migration is the same as the + // pgdump from a database where we directly run the newSchemaDDL + oldDbDump, err := pgdump.GetDump(oldDb, pgdump.WithSchemaOnly()) + suite.Require().NoError(err) + + newDbDump := suite.directlyRunDDLAndGetDump(append(onDbInitQueries, expects.outputState...)) + suite.Equal(newDbDump, oldDbDump, prettySprintPlan(plan)) + + // Make sure no diff is found if we try to regenerate a plan + plan, err = diff.GeneratePlan(context.Background(), oldDbConn, tempDbFactory, tc.newSchemaDDL, planOpts...) + suite.Require().NoError(err) + suite.Empty(plan.Statements, prettySprintPlan(plan)) +} + +func (suite *acceptanceTestSuite) directlyRunDDLAndGetDump(ddl []string) string { + newDb, err := suite.pgEngine.CreateDatabase() + suite.Require().NoError(err) + defer newDb.DropDB() + suite.Require().NoError(applyDDL(newDb, ddl)) + + newDbDump, err := pgdump.GetDump(newDb, pgdump.WithSchemaOnly()) + suite.Require().NoError(err) + return newDbDump +} + +func applyDDL(db *pgengine.DB, ddl []string) error { + conn, err := sql.Open("pgx", db.GetDSN()) + if err != nil { + return err + } + defer conn.Close() + + for _, stmt := range ddl { + _, err := conn.Exec(stmt) + if err != nil { + return fmt.Errorf("DDL:\n: %w"+stmt, err) + } + } + return nil +} + +func applyPlan(db *pgengine.DB, plan diff.Plan) error { + var ddl []string + for _, stmt := range plan.Statements { + ddl = append(ddl, stmt.ToSQL()) + } + return applyDDL(db, ddl) +} + +func getUniqueHazardTypesFromStatements(statements []diff.Statement) []diff.MigrationHazardType { + var seenHazardTypes = make(map[diff.MigrationHazardType]bool) + var hazardTypes []diff.MigrationHazardType + for _, stmt := range statements { + for _, hazard := range stmt.Hazards { + if _, hasHazard := seenHazardTypes[hazard.Type]; !hasHazard { + seenHazardTypes[hazard.Type] = true + hazardTypes = append(hazardTypes, hazard.Type) + } + } + } + return hazardTypes +} + +func prettySprintPlan(plan diff.Plan) string { + return fmt.Sprintf("%# v", pretty.Formatter(plan.Statements)) +} + +func TestAcceptanceSuite(t *testing.T) { + suite.Run(t, new(acceptanceTestSuite)) +} diff --git a/internal/migration_acceptance_tests/check_constraint_cases_test.go b/internal/migration_acceptance_tests/check_constraint_cases_test.go new file mode 100644 index 0000000..599e2ba --- /dev/null +++ b/internal/migration_acceptance_tests/check_constraint_cases_test.go @@ -0,0 +1,512 @@ +package migration_acceptance_tests + +import ( + "github.com/stripe/pg-schema-diff/pkg/diff" +) + +var checkConstraintCases = []acceptanceTestCase{ + { + name: "No-op", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255), + bar BIGINT CHECK ( bar > id ) + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255), + bar BIGINT CHECK ( bar > id ) + ); + `, + }, + }, + { + name: "Add check constraint", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255), + bar BIGINT + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255), + bar BIGINT CHECK ( bar > id ) + ); + `, + }, + }, + { + name: "Add check constraint with UDF dependency should error", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255), + bar INT + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE FUNCTION add(a integer, b integer) RETURNS integer + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN a + b; + + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255), + bar INT CHECK ( add(bar, id) > 0 ) + ); + `, + }, + vanillaExpectations: expectations{ + planErrorIs: diff.ErrNotImplemented, + }, + dataPackingExpectations: expectations{ + planErrorIs: diff.ErrNotImplemented, + }, + }, + { + name: "Add check constraint with system function dependency should not error", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255), + bar INT + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255), + bar INT CHECK ( to_timestamp(id) <= CURRENT_TIMESTAMP ) + ); + `, + }, + }, + { + name: "Add multiple check constraints", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255), + bar BIGINT + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255), + bar BIGINT, + CHECK ( bar > id ), CHECK ( bar IS NOT NULL ), CHECK (bar > 0) + ); + `, + }, + }, + { + name: "Add check constraints to new column", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255) + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255), + bar BIGINT CHECK ( bar > id ), CHECK ( bar IS NOT NULL ), CHECK (bar > 0) + ); + `, + }, + }, + { + name: "Add check constraint with quoted identifiers", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + "ID" INT PRIMARY KEY, + foo VARCHAR(255), + "Bar" BIGINT + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + "ID" INT PRIMARY KEY, + foo VARCHAR(255), + "Bar" BIGINT + ); + ALTER TABLE foobar ADD CONSTRAINT "BAR_CHECK" CHECK ( "Bar" < "ID" ); + `, + }, + }, + { + name: "Add no inherit check constraint", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255), + bar BIGINT + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255), + bar BIGINT + ); + ALTER TABLE foobar ADD CONSTRAINT bar_check CHECK ( bar > id ) NO INHERIT; + `, + }, + }, + { + name: "Add No-Inherit, Not-Valid check constraint", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255), + bar BIGINT + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255), + bar BIGINT + ); + ALTER TABLE foobar ADD CONSTRAINT bar_check CHECK ( bar > id ) NO INHERIT NOT VALID; + `, + }, + }, + { + name: "Drop check constraint", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255), + bar BIGINT CHECK ( bar > id ) + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255), + bar BIGINT + ); + `, + }, + }, + { + name: "Drop check constraint with quoted identifiers", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + "ID" INT PRIMARY KEY, + foo VARCHAR(255), + "Bar" BIGINT + ); + ALTER TABLE foobar ADD CONSTRAINT "BAR_CHECK" CHECK ( "Bar" < "ID" ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + "ID" INT PRIMARY KEY, + foo VARCHAR(255), + "Bar" BIGINT + ); + `, + }, + }, + { + name: "Drop column with check constraints", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255), + bar BIGINT, + CHECK ( bar > id ), CHECK ( bar IS NOT NULL ), CHECK (bar > 0) + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255) + ); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{diff.MigrationHazardTypeDeletesData}, + }, + { + name: "Drop check constraint with UDF dependency should error", + oldSchemaDDL: []string{ + ` + CREATE FUNCTION add(a integer, b integer) RETURNS integer + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN a + b; + + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255), + bar INT CHECK ( add(bar, id) > 0 ) + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255), + bar INT + ); + `, + }, + vanillaExpectations: expectations{ + planErrorIs: diff.ErrNotImplemented, + }, + dataPackingExpectations: expectations{ + planErrorIs: diff.ErrNotImplemented, + }, + }, + { + name: "Drop check constraint with system function dependency should not error", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255), + bar INT CHECK ( to_timestamp(id) <= CURRENT_TIMESTAMP ) + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255), + bar INT + ); + `, + }, + }, + { + name: "Alter an invalid check constraint to be valid", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255), + bar BIGINT + ); + ALTER TABLE foobar ADD CONSTRAINT bar_check CHECK ( bar > id ) NOT VALID; + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255), + bar BIGINT + ); + ALTER TABLE foobar ADD CONSTRAINT bar_check CHECK ( bar > id ); + `, + }, + }, + { + name: "Alter a valid check constraint to be invalid", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255), + bar BIGINT + ); + ALTER TABLE foobar ADD CONSTRAINT bar_check CHECK ( bar > id ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255), + bar BIGINT + ); + ALTER TABLE foobar ADD CONSTRAINT bar_check CHECK ( bar > id ) NOT VALID; + `, + }, + }, + { + name: "Alter a No-Inherit check constraint to be Inheritable", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255), + bar BIGINT + ); + ALTER TABLE foobar ADD CONSTRAINT bar_check CHECK ( bar > id ) NO INHERIT; + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255), + bar BIGINT + ); + ALTER TABLE foobar ADD CONSTRAINT bar_check CHECK ( bar > id ); + `, + }, + }, + { + name: "Alter an Inheritable check constraint to be No-Inherit", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255), + bar BIGINT + ); + ALTER TABLE foobar ADD CONSTRAINT bar_check CHECK ( bar > id ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255), + bar BIGINT + ); + ALTER TABLE foobar ADD CONSTRAINT bar_check CHECK ( bar > id ) NO INHERIT; + `, + }, + }, + { + name: "Alter a check constraint expression", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255), + bar BIGINT CHECK (bar > id) + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255), + bar BIGINT CHECK (bar < id) + ); + `, + }, + }, + { + name: "Alter check constraint with UDF dependency should error", + oldSchemaDDL: []string{ + ` + CREATE FUNCTION add(a integer, b integer) RETURNS integer + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN a + b; + + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255), + bar INT + ); + ALTER TABLE foobar ADD CONSTRAINT some_constraint CHECK ( add(bar, id) > 0 ) NOT VALID; + `, + }, + newSchemaDDL: []string{ + ` + CREATE FUNCTION add(a integer, b integer) RETURNS integer + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN a + b; + + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255), + bar INT + ); + ALTER TABLE foobar ADD CONSTRAINT some_constraint CHECK ( add(bar, id) > 0 ); + `, + }, + vanillaExpectations: expectations{ + planErrorIs: diff.ErrNotImplemented, + }, + dataPackingExpectations: expectations{ + planErrorIs: diff.ErrNotImplemented, + }, + }, + { + name: "Alter check constraint with system function dependency should not error", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255), + bar INT + ); + ALTER TABLE foobar ADD CONSTRAINT some_constraint CHECK ( to_timestamp(id) <= CURRENT_TIMESTAMP ) NOT VALID; + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255), + bar INT + ); + ALTER TABLE foobar ADD CONSTRAINT some_constraint CHECK ( to_timestamp(id) <= CURRENT_TIMESTAMP ); + `, + }, + }, +} + +func (suite *acceptanceTestSuite) TestCheckConstraintAcceptanceTestCases() { + suite.runTestCases(checkConstraintCases) +} diff --git a/internal/migration_acceptance_tests/column_cases_test.go b/internal/migration_acceptance_tests/column_cases_test.go new file mode 100644 index 0000000..9ef60c9 --- /dev/null +++ b/internal/migration_acceptance_tests/column_cases_test.go @@ -0,0 +1,663 @@ +package migration_acceptance_tests + +import ( + "github.com/stripe/pg-schema-diff/pkg/diff" +) + +var columnAcceptanceTestCases = []acceptanceTestCase{ + { + name: "No-op", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255) COLLATE "C" DEFAULT '' NOT NULL, + bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + fizz BOOLEAN NOT NULL + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255) COLLATE "C" DEFAULT '' NOT NULL, + bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + fizz BOOLEAN NOT NULL + ); + `, + }, + }, + { + name: "Add one column with default", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + my_new_column VARCHAR(255) DEFAULT 'a' + ); + `, + }, + }, + { + name: "Add one column with quoted names", + oldSchemaDDL: []string{ + ` + CREATE TABLE "Foobar"( + id INT PRIMARY KEY + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE "Foobar"( + id INT PRIMARY KEY, + "My_new_column" VARCHAR(255) DEFAULT 'a' + ); + `, + }, + }, + { + name: "Add one column with nullability", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + my_new_column VARCHAR(255) NOT NULL + ); + `, + }, + }, + { + name: "Add one column with all options", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + my_new_column VARCHAR(255) COLLATE "C" NOT NULL DEFAULT 'a' + ); + `, + }, + }, + { + name: "Add one column and change ordering", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + my_new_column VARCHAR(255) NOT NULL DEFAULT 'a', + id INT PRIMARY KEY + ); + `, + }, + vanillaExpectations: expectations{ + planErrorIs: diff.ErrColumnOrderingChanged, + }, + dataPackingExpectations: expectations{ + outputState: []string{` + CREATE TABLE foobar( + id INT PRIMARY KEY, + my_new_column VARCHAR(255) NOT NULL DEFAULT 'a' + ) + `}, + }, + }, + { + name: "Delete one column", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + ); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeDeletesData, + diff.MigrationHazardTypeIndexDropped, + }, + }, + { + name: "Delete one column with quoted name", + oldSchemaDDL: []string{ + ` + CREATE TABLE "Foobar"( + "Id" INT PRIMARY KEY + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE "Foobar"( + ); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeDeletesData, + diff.MigrationHazardTypeIndexDropped, + }, + }, + { + name: "Modify data type (varchar -> char)", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foobar VARCHAR(255) NOT NULL + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foobar CHAR NOT NULL + ); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeImpactsDatabasePerformance, + }, + }, + { + name: "Modify data type (varchar -> TEXT) with compatible default", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foobar VARCHAR(255) DEFAULT 'some default' NOT NULL + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foobar TEXT DEFAULT 'some default' NOT NULL + ); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeImpactsDatabasePerformance, + }, + }, + { + name: "Modify data type and collation (varchar -> char)", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foobar VARCHAR(255) COLLATE "C" NOT NULL + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foobar CHAR COLLATE "POSIX" NOT NULL + ); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeImpactsDatabasePerformance, + }, + }, + { + name: "Modify data type to incompatible (bytea -> char)", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foobar bytea NOT NULL + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foobar CHAR NOT NULL + ); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeImpactsDatabasePerformance, + }, + }, + { + name: "Modify collation (default -> non-default)", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foobar VARCHAR(255) COLLATE "C" NOT NULL + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foobar VARCHAR(255) COLLATE "POSIX" NOT NULL + ); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeImpactsDatabasePerformance, + }, + }, + { + name: "Modify collation (non-default -> non-default)", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foobar VARCHAR(255) NOT NULL + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foobar VARCHAR(255) COLLATE "POSIX" NOT NULL + ); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeImpactsDatabasePerformance, + }, + }, + { + name: "Add Default", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foobar VARCHAR(255) + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foobar VARCHAR(255) DEFAULT '' + ); + `, + }, + }, + { + name: "Remove Default", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foobar VARCHAR(255) DEFAULT '' + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foobar VARCHAR(255) + ); + `, + }, + }, + { + name: "Change Default", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foobar VARCHAR(255) DEFAULT 'Something else' + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foobar VARCHAR(255) DEFAULT '' + ); + `, + }, + }, + { + name: "Set NOT NULL", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foobar VARCHAR(255) + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foobar VARCHAR(255) NOT NULL + ); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + }, + }, + { + name: "Remove NOT NULL", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foobar VARCHAR(255) NOT NULL + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foobar VARCHAR(255) + ); + `, + }, + }, + { + name: "Add default and change data type (new default is incompatible with old type)", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foobar INT + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foobar VARCHAR(255) DEFAULT 'SOMETHING' + ); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeImpactsDatabasePerformance, + }, + }, + { + name: "Change default and data type (new default is incompatible with old type)", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foobar INT DEFAULT 0 + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foobar VARCHAR(255) DEFAULT 'SOMETHING' + ); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeImpactsDatabasePerformance, + }, + }, + { + name: "Change default and data type (old default is incompatible with new type)", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foobar TEXT DEFAULT 'SOMETHING' + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foobar INT DEFAULT 8 + ); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeImpactsDatabasePerformance, + }, + vanillaExpectations: expectations{planErrorContains: "validating migration plan"}, + dataPackingExpectations: expectations{planErrorContains: "validating migration plan"}, + }, + { + name: "Change to not null", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foobar INT DEFAULT NULL + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foobar INT NOT NULL + ); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + }, + }, + { + name: "Change from NULL default to no default and NOT NULL", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foobar INT DEFAULT NULL + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foobar INT NOT NULL + ); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + }, + }, + { + name: "Change from NOT NULL to no NULL default", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foobar INT NOT NULL + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foobar INT DEFAULT NULL + ); + `, + }, + }, + { + name: "Change data type and to nullable", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foobar INT NOT NULL + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foobar SMALLINT NULL + ); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeImpactsDatabasePerformance, + }, + }, + { + name: "Change data type and to not nullable", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foobar INT + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foobar SMALLINT NOT NULL + ); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeImpactsDatabasePerformance, + }, + }, + { + name: "Change data type, nullability (NOT NULL), and default", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foobar TEXT DEFAULT 'some default' + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foobar CHAR NOT NULL DEFAULT 'A' + ); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeImpactsDatabasePerformance, + }, + }, + { + name: "Change data type and collation, nullability (NOT NULL), and default with quoted names", + oldSchemaDDL: []string{ + ` + CREATE TABLE "Foobar"( + id INT PRIMARY KEY, + "Foobar" TEXT DEFAULT 'some default' + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE "Foobar"( + id INT PRIMARY KEY, + "Foobar" CHAR COLLATE "POSIX" NOT NULL DEFAULT 'A' + ); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeImpactsDatabasePerformance, + }, + }, + { + name: "Change BIGINT to TIMESTAMP, nullability (NOT NULL), and default with current_timestamp", + oldSchemaDDL: []string{ + ` + CREATE TABLE "Foobar"( + id INT PRIMARY KEY, + some_time_col BIGINT + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE "Foobar"( + id INT PRIMARY KEY, + some_time_col TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL + ); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeImpactsDatabasePerformance, + }, + }, +} + +func (suite *acceptanceTestSuite) TestColumnAcceptanceTestCases() { + suite.runTestCases(columnAcceptanceTestCases) +} diff --git a/internal/migration_acceptance_tests/function_cases_test.go b/internal/migration_acceptance_tests/function_cases_test.go new file mode 100644 index 0000000..86a683c --- /dev/null +++ b/internal/migration_acceptance_tests/function_cases_test.go @@ -0,0 +1,480 @@ +package migration_acceptance_tests + +import ( + "github.com/stripe/pg-schema-diff/pkg/diff" +) + +var functionAcceptanceTestCases = []acceptanceTestCase{ + { + name: "No-op", + oldSchemaDDL: []string{ + ` + CREATE FUNCTION add(a integer, b integer) RETURNS integer + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN a + b; + + CREATE OR REPLACE FUNCTION increment(i integer) RETURNS integer AS $$ + BEGIN + RETURN i + 1; + END; + $$ LANGUAGE plpgsql; + `, + }, + newSchemaDDL: []string{ + ` + CREATE FUNCTION add(a integer, b integer) RETURNS integer + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN a + b; + + CREATE OR REPLACE FUNCTION increment(i integer) RETURNS integer AS $$ + BEGIN + RETURN i + 1; + END; + $$ LANGUAGE plpgsql; + `, + }, + }, + { + name: "Create functions (with conflicting names)", + oldSchemaDDL: nil, + newSchemaDDL: []string{ + ` + CREATE FUNCTION add(a integer, b integer) RETURNS integer + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN a + b; + CREATE FUNCTION add(a text, b text) RETURNS text + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN CONCAT(a, b); + `, + }, + }, + { + name: "Create functions with quoted names (with conflicting names)", + oldSchemaDDL: nil, + newSchemaDDL: []string{ + ` + CREATE FUNCTION "some add"(a integer, b integer) RETURNS integer + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN a + b; + CREATE FUNCTION "some add"(a text, b text) RETURNS text + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN CONCAT(a, b); + `, + }, + }, + { + name: "Create non-sql function", + oldSchemaDDL: nil, + newSchemaDDL: []string{ + ` + CREATE FUNCTION non_sql_func(i integer) RETURNS integer AS $$ + BEGIN + RETURN i + 1; + END; + $$ LANGUAGE plpgsql; + `}, + expectedHazardTypes: []diff.MigrationHazardType{diff.MigrationHazardTypeHasUntrackableDependencies}, + }, + { + name: "Create function with dependencies", + oldSchemaDDL: nil, + newSchemaDDL: []string{ + ` + CREATE FUNCTION add(a integer, b integer) RETURNS integer + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN a + b; + + CREATE FUNCTION "increment func"(i integer) RETURNS integer AS $$ + BEGIN + RETURN i + 1; + END; + $$ LANGUAGE plpgsql; + + CREATE FUNCTION function_with_dependencies(a integer, b integer) RETURNS integer + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN add(a, b) + "increment func"(a); + + -- function with conflicting name to ensure the deps specify param name + CREATE FUNCTION add(a text, b text) RETURNS text + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN CONCAT(a, b); + `}, + expectedHazardTypes: []diff.MigrationHazardType{diff.MigrationHazardTypeHasUntrackableDependencies}, + }, + { + name: "Drop functions (with conflicting names)", + oldSchemaDDL: []string{ + ` + CREATE FUNCTION add(a integer, b integer) RETURNS integer + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN a + b; + CREATE FUNCTION add(a text, b text) RETURNS text + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN CONCAT(a, b); + `, + }, + newSchemaDDL: nil, + }, + { + name: "Drop functions with quoted names (with conflicting names)", + oldSchemaDDL: []string{ + ` + CREATE FUNCTION "some add"(a integer, b integer) RETURNS integer + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN a + b; + CREATE FUNCTION "some add"(a text, b text) RETURNS text + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN CONCAT(a, b); + `, + }, + newSchemaDDL: nil, + }, + { + name: "Drop non-sql function", + oldSchemaDDL: []string{ + ` + CREATE FUNCTION non_sql_func(i integer) RETURNS integer AS $$ + BEGIN + RETURN i + 1; + END; + $$ LANGUAGE plpgsql; + `}, + newSchemaDDL: nil, + expectedHazardTypes: []diff.MigrationHazardType{diff.MigrationHazardTypeHasUntrackableDependencies}, + }, + { + name: "Drop function with dependencies", + oldSchemaDDL: []string{ + ` + CREATE FUNCTION add(a integer, b integer) RETURNS integer + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN a + b; + + CREATE FUNCTION "increment func"(i integer) RETURNS integer AS $$ + BEGIN + RETURN i + 1; + END; + $$ LANGUAGE plpgsql; + + CREATE FUNCTION function_with_dependencies(a integer, b integer) RETURNS integer + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN add(a, b) + "increment func"(a); + + -- function with conflicting name to ensure the deps specify param name + CREATE FUNCTION add(a text, b text) RETURNS text + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN CONCAT(a, b); + `, + }, + newSchemaDDL: nil, + expectedHazardTypes: []diff.MigrationHazardType{diff.MigrationHazardTypeHasUntrackableDependencies}, + }, + { + name: "Alter functions (with conflicting names)", + oldSchemaDDL: []string{ + ` + CREATE FUNCTION add(a integer, b integer) RETURNS integer + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN a + b; + CREATE FUNCTION add(a TEXT, b TEXT) RETURNS TEXT + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN CONCAT(a, b); + `, + }, + newSchemaDDL: []string{ + ` + CREATE FUNCTION add(a integer, b integer) RETURNS integer + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN a + a + b; + CREATE FUNCTION add(a TEXT, b TEXT) RETURNS TEXT + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN CONCAT(CONCAT(a, a), b); + `, + }, + }, + { + name: "Alter functions with quoted names (with conflicting names)", + oldSchemaDDL: []string{ + ` + CREATE FUNCTION "some add"(a integer, b integer) RETURNS integer + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN a + b; + CREATE FUNCTION "some add"(a TEXT, b TEXT) RETURNS TEXT + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN CONCAT(a, b); + `, + }, + newSchemaDDL: []string{ + ` + CREATE FUNCTION "some add"(a integer, b integer) RETURNS integer + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN a + a + b; + CREATE FUNCTION "some add"(a TEXT, b TEXT) RETURNS TEXT + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN CONCAT(CONCAT(a, a), b); + `, + }, + }, + { + name: "Alter non-sql function", + oldSchemaDDL: []string{ + ` + CREATE FUNCTION non_sql_func(i integer) RETURNS integer AS $$ + BEGIN + RETURN i + 1; + END; + $$ LANGUAGE plpgsql; + `}, + newSchemaDDL: []string{ + ` + CREATE FUNCTION non_sql_func(i integer) RETURNS integer AS $$ + BEGIN + RETURN i + 5; + END; + $$ LANGUAGE plpgsql; + `}, + expectedHazardTypes: []diff.MigrationHazardType{diff.MigrationHazardTypeHasUntrackableDependencies}, + }, + { + name: "Alter sql function to be non-sql function", + oldSchemaDDL: []string{ + ` + CREATE FUNCTION some_func(i integer) RETURNS integer + LANGUAGE SQL + IMMUTABLE + RETURN i + 5; + `}, + newSchemaDDL: []string{ + ` + CREATE FUNCTION some_func(i integer) RETURNS integer AS $$ + BEGIN + RETURN i + 1; + END; + $$ LANGUAGE plpgsql; + `}, + expectedHazardTypes: []diff.MigrationHazardType{diff.MigrationHazardTypeHasUntrackableDependencies}, + }, + { + name: "Alter non-sql function to be sql function (no dependency tracking error)", + oldSchemaDDL: []string{ + ` + CREATE FUNCTION some_func(i integer) RETURNS integer AS $$ + BEGIN + RETURN i + 1; + END; + $$ LANGUAGE plpgsql; + `}, + newSchemaDDL: []string{ + ` + CREATE FUNCTION some_func(i integer) RETURNS integer + LANGUAGE SQL + IMMUTABLE + RETURN i + 5; + `}, + }, + { + name: "Alter a function's dependencies", + oldSchemaDDL: []string{ + ` + CREATE FUNCTION add(a integer, b integer) RETURNS integer + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN a + b; + + CREATE FUNCTION "increment func"(i integer) RETURNS integer AS $$ + BEGIN + RETURN i + 1; + END; + $$ LANGUAGE plpgsql; + + CREATE FUNCTION function_with_dependencies(a integer, b integer) RETURNS integer + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN add(a, b) + "increment func"(a); + + -- function with conflicting name to ensure the deps specify param name + CREATE FUNCTION add(a text, b text) RETURNS text + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN CONCAT(a, b); + `, + }, + newSchemaDDL: []string{ + ` + CREATE FUNCTION add(a integer, b integer) RETURNS integer + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN a + b; + + CREATE FUNCTION "increment func"(i integer) RETURNS integer AS $$ + BEGIN + RETURN i + 1; + END; + $$ LANGUAGE plpgsql; + + CREATE FUNCTION "decrement func"(i integer) RETURNS integer AS $$ + BEGIN + RETURN i + 1; + END; + $$ LANGUAGE plpgsql; + + CREATE FUNCTION function_with_dependencies(a integer, b integer) RETURNS integer + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN add(a, b) + "decrement func"(a); + + -- function with conflicting name to ensure the deps specify param name + CREATE FUNCTION add(a text, b text) RETURNS text + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN CONCAT(a, b); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{diff.MigrationHazardTypeHasUntrackableDependencies}, + }, + { + name: "Alter a dependent function", + oldSchemaDDL: []string{ + ` + CREATE FUNCTION add(a integer, b integer) RETURNS integer + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN a + b; + + CREATE FUNCTION "increment func"(i integer) RETURNS integer AS $$ + BEGIN + RETURN i + 1; + END; + $$ LANGUAGE plpgsql; + + CREATE FUNCTION function_with_dependencies(a integer, b integer) RETURNS integer + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN add(a, b) + "increment func"(a); + `, + }, + newSchemaDDL: []string{ + ` + CREATE FUNCTION add(a integer, b integer) RETURNS integer + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN a + b; + + CREATE FUNCTION "increment func"(i integer) RETURNS int AS $$ + BEGIN + RETURN i + 5; + END; + $$ LANGUAGE plpgsql; + + CREATE FUNCTION function_with_dependencies(a integer, b integer) RETURNS integer + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN add(a, b) + "increment func"(a); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{diff.MigrationHazardTypeHasUntrackableDependencies}, + }, + { + name: "Alter a function to no longer depend on a function and drop that function", + oldSchemaDDL: []string{ + ` + CREATE FUNCTION add(a integer, b integer) RETURNS integer + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN a + b; + + CREATE FUNCTION "increment func"(i integer) RETURNS integer AS $$ + BEGIN + RETURN i + 1; + END; + $$ LANGUAGE plpgsql; + + CREATE FUNCTION function_with_dependencies(a integer, b integer) RETURNS integer + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN add(a, b) + "increment func"(a); + `, + }, + newSchemaDDL: []string{ + ` + CREATE FUNCTION add(a integer, b integer) RETURNS integer + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN a + b; + + + CREATE FUNCTION function_with_dependencies(a integer, b integer) RETURNS integer + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN add(a, b); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{diff.MigrationHazardTypeHasUntrackableDependencies}, + }, +} + +func (suite *acceptanceTestSuite) TestFunctionAcceptanceTestCases() { + suite.runTestCases(functionAcceptanceTestCases) +} diff --git a/internal/migration_acceptance_tests/index_cases_test.go b/internal/migration_acceptance_tests/index_cases_test.go new file mode 100644 index 0000000..a52d682 --- /dev/null +++ b/internal/migration_acceptance_tests/index_cases_test.go @@ -0,0 +1,553 @@ +package migration_acceptance_tests + +import ( + "github.com/stripe/pg-schema-diff/pkg/diff" +) + +var indexAcceptanceTestCases = []acceptanceTestCase{ + { + name: "No-op", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255), + bar TEXT, + fizz INT + ); + CREATE INDEX some_idx ON foobar USING hash (foo); + CREATE UNIQUE INDEX some_other_idx ON foobar (bar DESC, fizz); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255), + bar TEXT, + fizz INT + ); + CREATE INDEX some_idx ON foobar USING hash (foo); + CREATE UNIQUE INDEX some_other_idx ON foobar (bar DESC, fizz); + `, + }, + }, + { + name: "Add a normal index", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255) + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255) + ); + CREATE INDEX some_idx ON foobar(id DESC, foo); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeIndexBuild, + }, + }, + { + name: "Add a hash index", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255) + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255) + ); + CREATE INDEX some_idx ON foobar USING hash (id); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeIndexBuild, + }, + }, + { + name: "Add a normal index with quoted names", + oldSchemaDDL: []string{ + ` + CREATE TABLE "Foobar"( + id INT PRIMARY KEY, + "Foo" VARCHAR(255) + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE "Foobar"( + id INT PRIMARY KEY, + "Foo" VARCHAR(255) + ); + CREATE INDEX "Some_idx" ON "Foobar"(id, "Foo"); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeIndexBuild, + }, + }, + { + name: "Add a unique index", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255) NOT NULL + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255) NOT NULL + ); + CREATE UNIQUE INDEX some_unique_idx ON foobar(foo); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeIndexBuild, + }, + }, + { + name: "Add a primary key", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY + ); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeIndexBuild, + }, + }, + { + name: "Add a primary key on NOT NULL column", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT NOT NULL + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT NOT NULL PRIMARY KEY + ); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeIndexBuild, + }, + }, + { + name: "Add a primary key when the index already exists", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT + ); + CREATE UNIQUE INDEX foobar_primary_key ON foobar(id); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT + ); + CREATE UNIQUE INDEX foobar_primary_key ON foobar(id); + ALTER TABLE foobar ADD CONSTRAINT foobar_primary_key PRIMARY KEY USING INDEX foobar_primary_key; + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + }, + }, + { + name: "Add a primary key when the index already exists but has a name different to the constraint", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT + ); + CREATE UNIQUE INDEX foobar_idx ON foobar(id); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT + ); + CREATE UNIQUE INDEX foobar_idx ON foobar(id); + -- This renames the index + ALTER TABLE foobar ADD CONSTRAINT foobar_primary_key PRIMARY KEY USING INDEX foobar_idx; + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeIndexDropped, + diff.MigrationHazardTypeIndexBuild, + }, + }, + { + name: "Add a primary key when the index already exists but is not unique", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT + ); + CREATE INDEX foobar_idx ON foobar(id); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT + ); + CREATE UNIQUE INDEX foobar_primary_key ON foobar(id); + ALTER TABLE foobar ADD CONSTRAINT foobar_primary_key PRIMARY KEY USING INDEX foobar_primary_key; + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeIndexDropped, + diff.MigrationHazardTypeIndexBuild, + }, + }, + { + name: "Delete a normal index", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255) NOT NULL + ); + CREATE INDEX some_inx ON foobar(id, foo); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255) + ); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeIndexDropped, + }, + }, + { + name: "Delete a normal index with quoted names", + oldSchemaDDL: []string{ + ` + CREATE TABLE "Foobar"( + id INT PRIMARY KEY, + "Foo" VARCHAR(255) + ); + CREATE INDEX "Some_idx" ON "Foobar"(id, "Foo"); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE "Foobar"( + id INT PRIMARY KEY, + "Foo" VARCHAR(255) NOT NULL + ); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeIndexDropped, + }, + }, + { + name: "Delete a unique index", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255) NOT NULL + ); + CREATE UNIQUE INDEX some_unique_idx ON foobar(foo); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255) NOT NULL + ); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeIndexDropped, + }, + }, + { + name: "Delete a primary key", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT + ); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeIndexDropped, + }, + }, + { + name: "Change an index (with a really long name) columns", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo TEXT NOT NULL, + bar BIGINT NOT NULL + ); + CREATE UNIQUE INDEX some_idx_with_a_really_long_name_that_is_nearly_61_chars ON foobar(foo, bar) + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo TEXT NOT NULL, + bar BIGINT NOT NULL + ); + CREATE UNIQUE INDEX some_idx_with_a_really_long_name_that_is_nearly_61_chars ON foobar(foo) + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeIndexDropped, + diff.MigrationHazardTypeIndexBuild, + }, + }, + { + name: "Change an index type", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo TEXT NOT NULL, + bar BIGINT NOT NULL + ); + CREATE INDEX some_idx ON foobar (foo) + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo TEXT NOT NULL, + bar BIGINT NOT NULL + ); + CREATE INDEX some_idx ON foobar USING hash (foo) + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeIndexDropped, + diff.MigrationHazardTypeIndexBuild, + }, + }, + { + name: "Change an index column ordering", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo TEXT NOT NULL, + bar BIGINT NOT NULL + ); + CREATE INDEX some_idx ON foobar (foo, bar) + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo TEXT NOT NULL, + bar BIGINT NOT NULL + ); + CREATE INDEX some_idx ON foobar (foo DESC, bar) + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeIndexDropped, + diff.MigrationHazardTypeIndexBuild, + }, + }, + { + name: "Delete columns and associated index", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo TEXT NOT NULL, + bar BIGINT NOT NULL + ); + CREATE UNIQUE INDEX some_idx ON foobar(foo, bar); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT + ); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeDeletesData, + diff.MigrationHazardTypeIndexDropped, + }, + }, + { + name: "Switch primary key and make old key nullable", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT NOT NULL PRIMARY KEY, + foo TEXT NOT NULL + ); + CREATE UNIQUE INDEX some_idx ON foobar(foo); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo INT NOT NULL PRIMARY KEY + ); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeIndexDropped, + diff.MigrationHazardTypeIndexBuild, + diff.MigrationHazardTypeImpactsDatabasePerformance, + }, + }, + { + name: "Switch primary key with quoted name", + oldSchemaDDL: []string{ + ` + CREATE TABLE "Foobar"( + "Id" INT NOT NULL PRIMARY KEY, + foo TEXT NOT NULL + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE "Foobar"( + "Id" INT, + foo INT NOT NULL PRIMARY KEY + ); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeIndexDropped, + diff.MigrationHazardTypeIndexBuild, + diff.MigrationHazardTypeImpactsDatabasePerformance, + }, + }, + { + name: "Switch primary key when the original primary key constraint has a non-default name", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT NOT NULL, + foo TEXT NOT NULL + ); + CREATE UNIQUE INDEX unique_idx ON foobar(id); + ALTER TABLE foobar ADD CONSTRAINT non_default_primary_key PRIMARY KEY USING INDEX unique_idx; + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo INT NOT NULL PRIMARY KEY + ); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeImpactsDatabasePerformance, + diff.MigrationHazardTypeIndexDropped, + diff.MigrationHazardTypeIndexBuild, + }, + }, + { + name: "Alter primary key columns (name stays same)", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT NOT NULL, + foo TEXT NOT NULL + ); + CREATE UNIQUE INDEX unique_idx ON foobar(id); + ALTER TABLE foobar ADD CONSTRAINT non_default_primary_key PRIMARY KEY USING INDEX unique_idx; + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo INT NOT NULL + ); + CREATE UNIQUE INDEX unique_idx ON foobar(id, foo); + ALTER TABLE foobar ADD CONSTRAINT non_default_primary_key PRIMARY KEY USING INDEX unique_idx; + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeImpactsDatabasePerformance, + diff.MigrationHazardTypeIndexDropped, + diff.MigrationHazardTypeIndexBuild, + }, + }, +} + +func (suite *acceptanceTestSuite) TestIndexAcceptanceTestCases() { + suite.runTestCases(indexAcceptanceTestCases) +} diff --git a/internal/migration_acceptance_tests/local_partition_index_cases_test.go b/internal/migration_acceptance_tests/local_partition_index_cases_test.go new file mode 100644 index 0000000..cd96d2b --- /dev/null +++ b/internal/migration_acceptance_tests/local_partition_index_cases_test.go @@ -0,0 +1,442 @@ +package migration_acceptance_tests + +import ( + "github.com/stripe/pg-schema-diff/pkg/diff" +) + +var localPartitionIndexAcceptanceTestCases = []acceptanceTestCase{ + { + name: "No-op", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255), + bar TEXT, + fizz INT, + PRIMARY KEY (foo, id) + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + + CREATE INDEX foobar_1_some_idx ON foobar_1 (foo); + CREATE UNIQUE INDEX foobar_2_some_unique_idx ON foobar_2 (foo); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255), + bar TEXT, + fizz INT, + PRIMARY KEY (foo, id) + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + + CREATE INDEX foobar_1_some_idx ON foobar_1 (foo); + CREATE UNIQUE INDEX foobar_2_some_unique_idx ON foobar_2 (foo); + `, + }, + }, + { + name: "Add local indexes", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255), + bar TEXT, + fizz BYTEA + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255), + bar TEXT, + fizz BYTEA + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + + CREATE INDEX foobar_1_some_idx ON foobar_1(foo, id); + CREATE INDEX foobar_2_some_idx ON foobar_2(foo, bar); + CREATE INDEX foobar_3_some_idx ON foobar_3(foo, fizz); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeIndexBuild, + }, + }, + { + name: "Add a unique local index", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255) + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255) + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + + CREATE UNIQUE INDEX foobar_1_some_idx ON foobar_1 (foo); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeIndexBuild, + }, + }, + { + name: "Add local primary keys", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255), + bar TEXT, + fizz bytea + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255), + bar TEXT, + fizz bytea + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar( + CONSTRAINT "foobar1_PRIMARY_KEY" PRIMARY KEY (foo, id) + ) FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar( + CONSTRAINT "foobar2_PRIMARY_KEY" PRIMARY KEY (foo, bar) + ) FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar( + CONSTRAINT "foobar3_PRIMARY_KEY" PRIMARY KEY (foo, fizz) + ) FOR VALUES IN ('foo_3'); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeIndexBuild, + }, + }, + { + name: "Delete a local index", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255), + bar TEXT, + fizz BYTEA + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + + CREATE INDEX foobar_1_some_idx ON foobar_1(foo, id); + CREATE INDEX foobar_2_some_idx ON foobar_2(foo, bar); + CREATE INDEX foobar_3_some_idx ON foobar_3(foo, fizz); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255), + bar TEXT, + fizz BYTEA + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + + CREATE INDEX foobar_1_some_idx ON foobar_1(foo, id); + CREATE INDEX foobar_3_some_idx ON foobar_3(foo, fizz); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeIndexDropped, + }, + }, + { + name: "Delete a unique local index", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255) + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + + CREATE UNIQUE INDEX some_unique_idx ON foobar_1(foo, id); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255) + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeIndexDropped, + }, + }, + { + name: "Delete a primary key", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255), + bar TEXT, + fizz bytea + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar( + CONSTRAINT "foobar1_PRIMARY_KEY" PRIMARY KEY (foo, id) + ) FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar( + CONSTRAINT "foobar2_PRIMARY_KEY" PRIMARY KEY (foo, bar) + ) FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar( + CONSTRAINT "foobar3_PRIMARY_KEY" PRIMARY KEY (foo, fizz) + ) FOR VALUES IN ('foo_3'); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255), + bar TEXT, + fizz bytea + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar( + CONSTRAINT "foobar1_PRIMARY_KEY" PRIMARY KEY (foo, id) + ) FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar( + CONSTRAINT "foobar2_PRIMARY_KEY" PRIMARY KEY (foo, bar) + ) FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeIndexDropped, + }, + }, + { + name: "Change an index columns", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255) + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + + CREATE UNIQUE INDEX some_unique_idx ON foobar_1(foo, id); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255) + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + + CREATE UNIQUE INDEX some_unique_idx ON foobar_1(id, foo); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeIndexDropped, + diff.MigrationHazardTypeIndexBuild, + }, + }, + { + name: "Delete columns and associated index", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255) + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + + CREATE UNIQUE INDEX some_unique_idx ON foobar_1(foo, id); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + foo VARCHAR(255) + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeDeletesData, + diff.MigrationHazardTypeIndexDropped, + }, + vanillaExpectations: expectations{ + planErrorIs: diff.ErrColumnOrderingChanged, + }, + }, + { + name: "Switch primary key", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255), + bar TEXT, + fizz bytea + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar( + CONSTRAINT "foobar1_PRIMARY_KEY" PRIMARY KEY (foo, id) + ) FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar( + CONSTRAINT "foobar2_PRIMARY_KEY" PRIMARY KEY (foo, bar) + ) FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar( + CONSTRAINT "foobar3_PRIMARY_KEY" PRIMARY KEY (foo, fizz) + ) FOR VALUES IN ('foo_3'); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255), + bar TEXT, + fizz bytea + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar( + CONSTRAINT "foobar1_PRIMARY_KEY" PRIMARY KEY (foo, id) + ) FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar( + CONSTRAINT "foobar2_PRIMARY_KEY" PRIMARY KEY (foo, fizz) + ) FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar( + CONSTRAINT "foobar3_PRIMARY_KEY" PRIMARY KEY (foo, fizz) + ) FOR VALUES IN ('foo_3'); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeIndexDropped, + diff.MigrationHazardTypeIndexBuild, + }, + }, + { + name: "Add a primary key when the index already exists", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255) NOT NULL, + bar TEXT NOT NULL, + fizz bytea + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE UNIQUE INDEX "foobar1_PRIMARY_KEY" ON foobar_1(foo, id); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255), + bar TEXT, + fizz bytea + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar( + CONSTRAINT "foobar1_PRIMARY_KEY" PRIMARY KEY (foo, id) + ) FOR VALUES IN ('foo_1'); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + }, + }, + { + name: "Switch partitioned index to local index", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255), + bar TEXT, + fizz BYTEA + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + + CREATE INDEX foobar_some_idx ON foobar(foo, id); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255), + bar TEXT, + fizz BYTEA + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + + CREATE INDEX foobar_1_foo_id_idx ON foobar_1(foo, id); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeIndexDropped, + diff.MigrationHazardTypeIndexBuild, + }, + }, +} + +func (suite *acceptanceTestSuite) TestLocalPartitionIndexAcceptanceTestCases() { + suite.runTestCases(localPartitionIndexAcceptanceTestCases) +} diff --git a/internal/migration_acceptance_tests/partitioned_index_cases_test.go b/internal/migration_acceptance_tests/partitioned_index_cases_test.go new file mode 100644 index 0000000..e56438b --- /dev/null +++ b/internal/migration_acceptance_tests/partitioned_index_cases_test.go @@ -0,0 +1,627 @@ +package migration_acceptance_tests + +import ( + "github.com/stripe/pg-schema-diff/pkg/diff" +) + +var partitionedIndexAcceptanceTestCases = []acceptanceTestCase{ + { + name: "No-op", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255), + bar TEXT, + fizz INT, + PRIMARY KEY (foo, id) + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + CREATE INDEX some_idx ON foobar USING hash (foo); + CREATE UNIQUE INDEX some_other_idx ON foobar(foo DESC, fizz); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255), + bar TEXT, + fizz INT, + PRIMARY KEY (foo, id) + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + CREATE INDEX some_idx ON foobar USING hash (foo); + CREATE UNIQUE INDEX some_other_idx ON foobar(foo DESC, fizz); + `, + }, + }, + { + name: "Add a normal partitioned index", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255) + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255) + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + + CREATE INDEX some_idx ON foobar(id DESC, foo); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeIndexBuild, + }, + }, + { + name: "Add a hash index", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255) + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255) + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + + CREATE INDEX some_idx ON foobar USING hash (foo); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeIndexBuild, + }, + }, + { + name: "Add a unique partitioned index", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255) + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255) + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + CREATE UNIQUE INDEX some_unique_idx ON foobar(foo, id); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeIndexBuild, + }, + }, + { + name: "Add a normal partitioned index with quotes names", + oldSchemaDDL: []string{ + ` + CREATE TABLE "Foobar"( + id INT, + "Foo" VARCHAR(255) + ) PARTITION BY LIST ("Foo"); + CREATE TABLE "FOOBAR_1" PARTITION OF "Foobar" FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF "Foobar" FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF "Foobar" FOR VALUES IN ('foo_3'); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE "Foobar"( + id INT, + "Foo" VARCHAR(255) + ) PARTITION BY LIST ("Foo"); + CREATE TABLE "FOOBAR_1" PARTITION OF "Foobar" FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF "Foobar" FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF "Foobar" FOR VALUES IN ('foo_3'); + + CREATE INDEX "SOME_IDX" ON "Foobar"(id, "Foo"); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeIndexBuild, + }, + }, + { + name: "Add a primary key", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255) + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255), + PRIMARY KEY (foo, id) + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeAcquiresShareLock, + diff.MigrationHazardTypeIndexBuild, + }, + }, + { + name: "Add a partitioned index that is used by a local primary key", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255) + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255) + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE UNIQUE INDEX some_idx ON ONLY foobar(foo, id); + CREATE UNIQUE INDEX foobar_1_pkey ON foobar_1(foo, id); + ALTER TABLE foobar_1 ADD CONSTRAINT foobar_1_pkey PRIMARY KEY USING INDEX foobar_1_pkey; + ALTER INDEX some_idx ATTACH PARTITION foobar_1_pkey; + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeIndexBuild, + }, + }, + { + name: "Add a primary key with quoted names", + oldSchemaDDL: []string{ + ` + CREATE TABLE "Foobar"( + "Id" INT, + "FOO" VARCHAR(255) + ) PARTITION BY LIST ("FOO"); + CREATE TABLE foobar_1 PARTITION OF "Foobar" FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF "Foobar" FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF "Foobar" FOR VALUES IN ('foo_3'); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE "Foobar"( + "Id" INT, + "FOO" VARCHAR(255) + ) PARTITION BY LIST ("FOO"); + CREATE TABLE foobar_1 PARTITION OF "Foobar" FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF "Foobar" FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF "Foobar" FOR VALUES IN ('foo_3'); + ALTER TABLE "Foobar" ADD CONSTRAINT "FOOBAR_PK" PRIMARY KEY("FOO", "Id") + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeIndexBuild, + diff.MigrationHazardTypeAcquiresShareLock, + }, + }, + { + name: "Add a partitioned primary key when the local index already exists", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255) + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE UNIQUE INDEX foobar_1_unique_idx ON foobar_1(foo, id); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255) + ) PARTITION BY LIST (foo); + ALTER TABLE foobar ADD CONSTRAINT foobar_pkey PRIMARY KEY (foo, id); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeAcquiresShareLock, + diff.MigrationHazardTypeIndexDropped, + diff.MigrationHazardTypeIndexBuild, + }, + }, + { + name: "Add a unique index when the local index already exists", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255) + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE UNIQUE INDEX foobar_1_foo_id_idx ON foobar_1(foo, id); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255) + ) PARTITION BY LIST (foo); + ALTER TABLE foobar ADD CONSTRAINT foobar_pkey PRIMARY KEY (foo, id); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE UNIQUE INDEX foobar_unique ON foobar(foo, id); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeAcquiresShareLock, + diff.MigrationHazardTypeIndexBuild, + }, + }, + { + name: "Delete a normal partitioned index", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255) + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + CREATE INDEX some_idx ON foobar(foo, id); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255) + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeIndexDropped, + }, + }, + { + name: "Delete a unique partitioned index", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255) + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + CREATE UNIQUE INDEX some_unique_idx ON foobar(foo, id); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255) + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeIndexDropped, + }, + }, + { + name: "Delete a primary key", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255), + PRIMARY KEY (foo, id) + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255) + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeIndexDropped, + }, + }, + { + name: "Change an index columns", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255), + bar INT + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + + CREATE INDEX some_idx ON foobar(id, foo); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255), + bar INT + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + + CREATE INDEX some_idx ON foobar(foo, bar); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeIndexDropped, + diff.MigrationHazardTypeIndexBuild, + }, + }, + { + name: "Change an index type", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255), + bar INT + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + + CREATE INDEX some_idx ON foobar(foo); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255), + bar INT + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + + CREATE INDEX some_idx ON foobar USING hash (foo); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeIndexDropped, + diff.MigrationHazardTypeIndexBuild, + }, + }, + { + name: "Change an index column ordering", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255), + bar INT + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + + CREATE INDEX some_idx ON foobar (foo, bar); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255), + bar INT + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + + CREATE INDEX some_idx ON foobar (bar, foo); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeIndexDropped, + diff.MigrationHazardTypeIndexBuild, + }, + }, + { + name: "Delete columns and associated index", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255) + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + + CREATE INDEX some_idx ON foobar(id, foo); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + foo VARCHAR(255) + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeDeletesData, + diff.MigrationHazardTypeIndexDropped, + }, + vanillaExpectations: expectations{ + planErrorIs: diff.ErrColumnOrderingChanged, + }, + }, + { + name: "Switch primary key", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255), + PRIMARY KEY (foo) + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255), + PRIMARY KEY (foo, id) + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresShareLock, + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeIndexDropped, + diff.MigrationHazardTypeIndexBuild, + }, + }, + { + name: "Attach an unnattached, invalid index", + oldSchemaDDL: []string{ + ` + CREATE TABLE "Foobar"( + id INT, + foo VARCHAR(255), + PRIMARY KEY (foo) + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF "Foobar" FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF "Foobar" FOR VALUES IN ('foo_2'); + CREATE TABLE "Foobar_3" PARTITION OF "Foobar" FOR VALUES IN ('foo_3'); + + CREATE INDEX "Partitioned_Idx" ON ONLY "Foobar"(foo); + + CREATE INDEX "foobar_1_part" ON foobar_1(foo); + ALTER INDEX "Partitioned_Idx" ATTACH PARTITION "foobar_1_part"; + + CREATE INDEX "foobar_2_part" ON foobar_2(foo); + ALTER INDEX "Partitioned_Idx" ATTACH PARTITION "foobar_2_part"; + + CREATE INDEX "Foobar_3_Part" ON "Foobar_3"(foo); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE "Foobar"( + id INT, + foo VARCHAR(255), + PRIMARY KEY (foo) + ) PARTITION BY LIST (foo); + CREATE TABLE foobar_1 PARTITION OF "Foobar" FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF "Foobar" FOR VALUES IN ('foo_2'); + CREATE TABLE "Foobar_3" PARTITION OF "Foobar" FOR VALUES IN ('foo_3'); + + CREATE INDEX "Partitioned_Idx" ON ONLY "Foobar"(foo); + + CREATE INDEX "foobar_1_part" ON foobar_1(foo); + ALTER INDEX "Partitioned_Idx" ATTACH PARTITION "foobar_1_part"; + + CREATE INDEX "foobar_2_part" ON foobar_2(foo); + ALTER INDEX "Partitioned_Idx" ATTACH PARTITION "foobar_2_part"; + + CREATE INDEX "Foobar_3_Part" ON "Foobar_3"(foo); + ALTER INDEX "Partitioned_Idx" ATTACH PARTITION "Foobar_3_Part"; + `, + }, + }, +} + +func (suite *acceptanceTestSuite) TestPartitionedIndexAcceptanceTestCases() { + suite.runTestCases(partitionedIndexAcceptanceTestCases) +} diff --git a/internal/migration_acceptance_tests/partitioned_table_cases_test.go b/internal/migration_acceptance_tests/partitioned_table_cases_test.go new file mode 100644 index 0000000..d764db0 --- /dev/null +++ b/internal/migration_acceptance_tests/partitioned_table_cases_test.go @@ -0,0 +1,810 @@ +package migration_acceptance_tests + +import ( + "github.com/stripe/pg-schema-diff/pkg/diff" +) + +var partitionedTableAcceptanceTestCases = []acceptanceTestCase{ + { + name: "No-op", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255), + bar TEXT COLLATE "POSIX", + fizz INT, + CHECK ( fizz > 0 ), + PRIMARY KEY (foo, id) + ) PARTITION BY LIST (foo); + + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + + -- partitioned indexes + CREATE UNIQUE INDEX foobar_unique_idx ON foobar(foo, fizz); + + -- local indexes + CREATE INDEX foobar_1_local_idx ON foobar_1(foo); + CREATE UNIQUE INDEX foobar_2_local_unique_idx ON foobar_2(foo); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255), + bar TEXT COLLATE "POSIX", + fizz INT, + CHECK ( fizz > 0 ), + PRIMARY KEY (foo, id) + ) PARTITION BY LIST (foo); + + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + + -- partitioned indexes + CREATE UNIQUE INDEX foobar_unique_idx ON foobar(foo, fizz); + + -- local indexes + CREATE INDEX foobar_1_local_idx ON foobar_1(foo); + CREATE UNIQUE INDEX foobar_2_local_unique_idx ON foobar_2(foo); + `, + }, + }, + { + name: "Create partitioned table with shared primary key", + oldSchemaDDL: nil, + newSchemaDDL: []string{ + ` + CREATE TABLE "Foobar"( + id INT, + foo VARCHAR(255), + bar TEXT COLLATE "POSIX" NOT NULL DEFAULT 'some default', + fizz INT, + CHECK ( fizz > 0 ), + PRIMARY KEY (foo, id) + ) PARTITION BY LIST (foo); + + CREATE TABLE "FOOBAR_1" PARTITION OF "Foobar"( + foo NOT NULL, + bar NOT NULL + ) FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF "Foobar" FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF "Foobar" FOR VALUES IN ('foo_3'); + + -- partitioned indexes + CREATE UNIQUE INDEX foobar_unique_idx ON "Foobar"(foo, fizz); + + -- local indexes + CREATE INDEX foobar_1_local_idx ON "FOOBAR_1"(foo); + CREATE UNIQUE INDEX foobar_2_local_unique_idx ON foobar_2(foo); + `, + }, + dataPackingExpectations: expectations{ + outputState: []string{ + ` + CREATE TABLE "Foobar"( + id INT, + fizz INT, + foo VARCHAR(255), + bar TEXT COLLATE "POSIX" NOT NULL DEFAULT 'some default', + CHECK ( fizz > 0 ), + PRIMARY KEY (foo, id) + ) PARTITION BY LIST (foo); + + CREATE TABLE "FOOBAR_1" PARTITION OF "Foobar"( + foo NOT NULL, + bar NOT NULL + ) FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF "Foobar" FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF "Foobar" FOR VALUES IN ('foo_3'); + + -- partitioned indexes + CREATE UNIQUE INDEX foobar_unique_idx ON "Foobar"(foo, fizz); + + -- local indexes + CREATE INDEX foobar_1_local_idx ON "FOOBAR_1"(foo); + CREATE UNIQUE INDEX foobar_2_local_unique_idx ON foobar_2(foo); + `, + }, + }, + }, + { + name: "Create partitioned table with local primary keys", + oldSchemaDDL: nil, + newSchemaDDL: []string{ + ` + CREATE TABLE "Foobar"( + id INT, + foo VARCHAR(255), + bar TEXT, + fizz INT, + CHECK ( fizz > 0 ) + ) PARTITION BY LIST (foo); + CREATE TABLE "FOOBAR_1" PARTITION OF "Foobar"( + foo NOT NULL, + bar NOT NULL, + PRIMARY KEY (foo, id) + ) FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF "Foobar"( + PRIMARY KEY (foo, bar) + ) FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF "Foobar"( + PRIMARY KEY (foo, fizz) + ) FOR VALUES IN ('foo_3'); + + -- partitioned indexes + CREATE UNIQUE INDEX foobar_unique_idx ON "Foobar"(foo, fizz); + + -- local indexes + CREATE INDEX foobar_1_local_idx ON "FOOBAR_1"(foo); + CREATE UNIQUE INDEX foobar_2_local_unique_idx ON foobar_2(foo); + `, + }, + dataPackingExpectations: expectations{ + outputState: []string{ + ` + CREATE TABLE "Foobar"( + id INT, + fizz INT, + foo VARCHAR(255), + bar TEXT, + CHECK ( fizz > 0 ) + ) PARTITION BY LIST (foo); + + CREATE TABLE "FOOBAR_1" PARTITION OF "Foobar"( + foo NOT NULL, + bar NOT NULL, + PRIMARY KEY (foo, id) + ) FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF "Foobar"( + PRIMARY KEY (foo, bar) + ) FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF "Foobar"( + PRIMARY KEY (foo, fizz) + ) FOR VALUES IN ('foo_3'); + + -- partitioned indexes + CREATE UNIQUE INDEX foobar_unique_idx ON "Foobar"(foo, fizz); + + -- local indexes + CREATE INDEX foobar_1_local_idx ON "FOOBAR_1"(foo); + CREATE UNIQUE INDEX foobar_2_local_unique_idx ON foobar_2(foo); + `, + }, + }, + }, + { + name: "Drop table", + oldSchemaDDL: []string{ + ` + CREATE TABLE "Foobar"( + id INT, + foo VARCHAR(255), + bar TEXT, + fizz INT, + CHECK ( fizz > 0 ), + PRIMARY KEY (foo, id) + ) PARTITION BY LIST (foo); + + CREATE TABLE "FOOBAR_1" PARTITION OF "Foobar"( + foo NOT NULL, + bar NOT NULL + ) FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF "Foobar" FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF "Foobar" FOR VALUES IN ('foo_3'); + + -- partitioned indexes + CREATE UNIQUE INDEX foobar_unique_idx ON "Foobar"(foo, fizz); + + -- local indexes + CREATE INDEX foobar_1_local_idx ON "FOOBAR_1"(foo); + CREATE UNIQUE INDEX foobar_2_local_unique_idx ON foobar_2(foo); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeDeletesData, + }, + newSchemaDDL: nil, + }, + { + name: "Alter table: New primary key, change column types, delete unique partitioned index index, new partitioned index, delete local index, add local index, validate check constraint", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255), + bar TEXT COLLATE "C", + fizz INT, + PRIMARY KEY (foo, id) + ) PARTITION BY LIST (foo); + + ALTER TABLE foobar ADD CONSTRAINT some_check_constraint CHECK ( fizz > 0 ) NOT VALID; + + CREATE TABLE foobar_1 PARTITION OF foobar( + foo NOT NULL, + bar NOT NULL + ) FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + + -- partitioned indexes + CREATE INDEX foobar_some_idx ON foobar(foo, bar); + + -- local indexes + CREATE INDEX foobar_1_local_idx ON foobar_1(foo); + CREATE UNIQUE INDEX foobar_2_local_unique_idx ON foobar_2(foo); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id TEXT, + foo VARCHAR(255), + bar TEXT COLLATE "POSIX", + fizz INT, + PRIMARY KEY (foo, id) + ) PARTITION BY LIST (foo); + ALTER TABLE foobar ADD CONSTRAINT some_check_constraint CHECK ( fizz > 0 ); + + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar( + bar NOT NULL + ) FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + + -- partitioned indexes + CREATE UNIQUE INDEX foobar_unique_idx ON foobar(foo, fizz); + + -- local indexes + CREATE UNIQUE INDEX foobar_2_local_unique_idx ON foobar_2(foo); + CREATE INDEX foobar_3_local_idx ON foobar_3(foo, bar); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeImpactsDatabasePerformance, + diff.MigrationHazardTypeIndexDropped, + diff.MigrationHazardTypeIndexBuild, + }, + }, + { + name: "Changing partition key def errors", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255), + bar TEXT, + fizz INT + ) PARTITION BY LIST (foo); + + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255), + bar TEXT, + fizz INT + ) PARTITION BY LIST (bar); + + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + `, + }, + vanillaExpectations: expectations{ + planErrorIs: diff.ErrNotImplemented, + }, + dataPackingExpectations: expectations{ + planErrorIs: diff.ErrNotImplemented, + }, + }, + { + name: "Unpartitioned to partitioned", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + version INT, + foo VARCHAR(255), + bar TEXT + ); + + CREATE INDEX some_idx on foobar(id); + + + CREATE FUNCTION increment_version() RETURNS TRIGGER AS $$ + BEGIN + NEW.version = OLD.version + 1; + RETURN NEW; + END; + $$ language 'plpgsql'; + + CREATE TRIGGER some_update_trigger + BEFORE UPDATE ON foobar + FOR EACH ROW + WHEN (OLD.* IS DISTINCT FROM NEW.*) + EXECUTE PROCEDURE increment_version(); + + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + version INT, + foo VARCHAR(255), + bar TEXT + ) PARTITION BY LIST (foo); + + CREATE INDEX some_idx on foobar(id); + + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + + CREATE FUNCTION increment_version() RETURNS TRIGGER AS $$ + BEGIN + NEW.version = OLD.version + 1; + RETURN NEW; + END; + $$ language 'plpgsql'; + + CREATE TRIGGER some_update_trigger + BEFORE UPDATE ON foobar + FOR EACH ROW + WHEN (OLD.* IS DISTINCT FROM NEW.*) + EXECUTE PROCEDURE increment_version(); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeDeletesData, + }, + }, + { + name: "Unpartitioned to partitioned and child tables already exist", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + version INT, + foo VARCHAR(255), + bar TEXT + ); + + CREATE INDEX some_idx on foobar(id); + + CREATE TABLE foobar_1( + id INT, + version INT, + foo VARCHAR(255), + bar TEXT + ); + + CREATE INDEX foobar_1_id_idx on foobar(id); + + + CREATE FUNCTION increment_version() RETURNS TRIGGER AS $$ + BEGIN + NEW.version = OLD.version + 1; + RETURN NEW; + END; + $$ language 'plpgsql'; + + CREATE TRIGGER some_update_trigger + BEFORE UPDATE ON foobar + FOR EACH ROW + WHEN (OLD.* IS DISTINCT FROM NEW.*) + EXECUTE PROCEDURE increment_version(); + + CREATE TRIGGER some_update_trigger + BEFORE UPDATE ON foobar_1 + FOR EACH ROW + WHEN (OLD.* IS DISTINCT FROM NEW.*) + EXECUTE PROCEDURE increment_version(); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + version INT, + foo VARCHAR(255), + bar TEXT + ) PARTITION BY LIST (foo); + + CREATE INDEX some_idx on foobar(id); + + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + + + CREATE FUNCTION increment_version() RETURNS TRIGGER AS $$ + BEGIN + NEW.version = OLD.version + 1; + RETURN NEW; + END; + $$ language 'plpgsql'; + + CREATE TRIGGER some_update_trigger + BEFORE UPDATE ON foobar + FOR EACH ROW + WHEN (OLD.* IS DISTINCT FROM NEW.*) + EXECUTE PROCEDURE increment_version(); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeDeletesData, + }, + }, + { + name: "Partitioned to unpartitioned", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + version INT, + foo VARCHAR(255), + bar TEXT + ) PARTITION BY LIST (foo); + + CREATE INDEX some_idx on foobar(id); + + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + + + CREATE FUNCTION increment_version() RETURNS TRIGGER AS $$ + BEGIN + NEW.version = OLD.version + 1; + RETURN NEW; + END; + $$ language 'plpgsql'; + + CREATE TRIGGER some_update_trigger + BEFORE UPDATE ON foobar + FOR EACH ROW + WHEN (OLD.* IS DISTINCT FROM NEW.*) + EXECUTE PROCEDURE increment_version(); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + version INT, + foo VARCHAR(255), + bar TEXT + ); + + CREATE INDEX some_idx on foobar(id); + + + CREATE FUNCTION increment_version() RETURNS TRIGGER AS $$ + BEGIN + NEW.version = OLD.version + 1; + RETURN NEW; + END; + $$ language 'plpgsql'; + + CREATE TRIGGER some_update_trigger + BEFORE UPDATE ON foobar + FOR EACH ROW + WHEN (OLD.* IS DISTINCT FROM NEW.*) + EXECUTE PROCEDURE increment_version(); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeDeletesData, + }, + }, + { + name: "Partitioned to unpartitioned and child tables still exist", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + version INT, + foo VARCHAR(255), + bar TEXT + ) PARTITION BY LIST (foo); + + CREATE INDEX some_idx on foobar(id); + + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + + + CREATE FUNCTION increment_version() RETURNS TRIGGER AS $$ + BEGIN + NEW.version = OLD.version + 1; + RETURN NEW; + END; + $$ language 'plpgsql'; + + CREATE TRIGGER some_update_trigger + BEFORE UPDATE ON foobar + FOR EACH ROW + WHEN (OLD.* IS DISTINCT FROM NEW.*) + EXECUTE PROCEDURE increment_version(); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + version INT, + foo VARCHAR(255), + bar TEXT + ); + + CREATE INDEX some_idx on foobar(id); + + CREATE TABLE foobar_1( + id INT, + version INT, + foo VARCHAR(255), + bar TEXT + ); + + CREATE INDEX foobar_1_id_idx on foobar(id); + + + CREATE FUNCTION increment_version() RETURNS TRIGGER AS $$ + BEGIN + NEW.version = OLD.version + 1; + RETURN NEW; + END; + $$ language 'plpgsql'; + + CREATE TRIGGER some_update_trigger + BEFORE UPDATE ON foobar + FOR EACH ROW + WHEN (OLD.* IS DISTINCT FROM NEW.*) + EXECUTE PROCEDURE increment_version(); + + CREATE TRIGGER some_update_trigger + BEFORE UPDATE ON foobar_1 + FOR EACH ROW + WHEN (OLD.* IS DISTINCT FROM NEW.*) + EXECUTE PROCEDURE increment_version(); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeDeletesData, + }, + }, + { + name: "Adding a partition", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + fizz INT, + foo VARCHAR(255), + bar TEXT, + CHECK ( fizz > 0 ), + PRIMARY KEY (foo, id) + ) PARTITION BY LIST (foo); + + -- partitioned indexes + CREATE UNIQUE INDEX some_partitioned_idx ON foobar(foo, bar); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + fizz INT, + foo VARCHAR(255), + bar TEXT, + CHECK ( fizz > 0 ), + PRIMARY KEY (foo, id) + ) PARTITION BY LIST (foo); + + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + + -- partitioned indexes + CREATE UNIQUE INDEX some_partitioned_idx ON foobar(foo, bar); + `, + }, + }, + { + name: "Adding a partition with local primary key that can back the unique index", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + fizz INT, + foo VARCHAR(255), + bar TEXT, + CHECK ( fizz > 0 ) + ) PARTITION BY LIST (foo); + + -- partitioned indexes + CREATE UNIQUE INDEX some_partitioned_idx ON foobar(foo, bar); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + fizz INT, + foo VARCHAR(255), + bar TEXT, + CHECK ( fizz > 0 ) + ) PARTITION BY LIST (foo); + + CREATE TABLE foobar_1 PARTITION OF foobar( + PRIMARY KEY (foo, bar) + ) FOR VALUES IN ('foo_1'); + + -- partitioned indexes + CREATE UNIQUE INDEX some_partitioned_idx ON foobar(foo, bar); + `, + }, + }, + { + name: "Deleting a partitioning errors", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255), + bar TEXT, + fizz INT + ) PARTITION BY LIST (foo); + + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255), + bar TEXT, + fizz INT + ) PARTITION BY LIST (foo); + `, + }, + vanillaExpectations: expectations{ + planErrorIs: diff.ErrNotImplemented, + }, + dataPackingExpectations: expectations{ + planErrorIs: diff.ErrNotImplemented, + }, + }, + { + name: "Altering a partition's 'FOR VALUES' errors", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255), + bar TEXT, + fizz INT + ) PARTITION BY LIST (foo); + + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255), + bar TEXT, + fizz INT + ) PARTITION BY LIST (foo); + + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_2'); + `, + }, + vanillaExpectations: expectations{ + planErrorIs: diff.ErrNotImplemented, + }, + dataPackingExpectations: expectations{ + planErrorIs: diff.ErrNotImplemented, + }, + }, + { + name: "Re-creating base table causes partitions to be re-created", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + fizz INT, + foo VARCHAR(255), + bar TEXT + ) PARTITION BY LIST (foo); + + CREATE TABLE foobar_1 PARTITION OF foobar( + PRIMARY KEY (foo, fizz) + ) FOR VALUES IN ('foo_1'); + + -- partitioned indexes + CREATE INDEX some_partitioned_idx ON foobar(foo, bar); + -- local indexes + CREATE INDEX some_local_idx ON foobar_1(foo, bar); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar_new( + id INT, + fizz INT, + foo VARCHAR(255), + bar TEXT + ) PARTITION BY LIST (foo); + + CREATE TABLE foobar_1 PARTITION OF foobar_new( + PRIMARY KEY (foo, fizz) + ) FOR VALUES IN ('foo_1'); + + -- partitioned indexes + CREATE INDEX some_partitioned_idx ON foobar_new(foo, bar); + -- local indexes + CREATE INDEX some_local_idx ON foobar_1(foo, bar); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeDeletesData, + }, + }, + { + name: "Can handle scenario where partition is not attached", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + fizz INT, + foo VARCHAR(255), + bar TEXT + ) PARTITION BY LIST (foo); + + CREATE TABLE foobar_1 ( + id INT, + fizz INT, + foo VARCHAR(255), + bar TEXT + ); + + -- partitioned indexes + CREATE INDEX some_partitioned_idx ON foobar(foo, bar); + -- local indexes + CREATE INDEX some_local_idx ON foobar_1(foo, bar); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + fizz INT, + foo VARCHAR(255), + bar TEXT + ) PARTITION BY LIST (foo); + + CREATE TABLE foobar_1 ( + id INT, + fizz INT, + foo VARCHAR(255), + bar TEXT + ); + ALTER TABLE foobar ATTACH PARTITION foobar_1 FOR VALUES IN ('foo_1'); + + -- partitioned indexes + CREATE INDEX some_partitioned_idx ON foobar(foo, bar); + -- local indexes + CREATE INDEX some_local_idx ON foobar_1(foo, bar); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeDeletesData, + }, + }, +} + +func (suite *acceptanceTestSuite) TestPartitionedTableAcceptanceTestCases() { + suite.runTestCases(partitionedTableAcceptanceTestCases) +} diff --git a/internal/migration_acceptance_tests/schema_cases_test.go b/internal/migration_acceptance_tests/schema_cases_test.go new file mode 100644 index 0000000..24e0b39 --- /dev/null +++ b/internal/migration_acceptance_tests/schema_cases_test.go @@ -0,0 +1,446 @@ +package migration_acceptance_tests + +import ( + "github.com/stripe/pg-schema-diff/pkg/diff" +) + +// These are tests for "public" schema" alterations (full migrations) +var schemaAcceptanceTests = []acceptanceTestCase{ + { + name: "No-op", + oldSchemaDDL: []string{ + ` + CREATE TABLE fizz( + ); + + CREATE FUNCTION add(a integer, b integer) RETURNS integer + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN a + b; + + CREATE FUNCTION increment(i integer) RETURNS integer AS $$ + BEGIN + RETURN i + 1; + END; + $$ LANGUAGE plpgsql; + + CREATE FUNCTION function_with_dependencies(a integer, b integer) RETURNS integer + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN add(a, b) + increment(a); + + CREATE TABLE foobar( + id INT, + foo VARCHAR(255) DEFAULT 'some default' NOT NULL CHECK (LENGTH(foo) > 0), + bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + fizz BOOLEAN NOT NULL, + PRIMARY KEY (foo, id) + ) PARTITION BY LIST(foo); + + CREATE TABLE foobar_1 PARTITION of foobar( + fizz NOT NULL + ) FOR VALUES IN ('foobar_1_val_1', 'foobar_1_val_2'); + + -- partitioned indexes + CREATE INDEX foobar_normal_idx ON foobar(foo DESC, fizz); + CREATE INDEX foobar_hash_idx ON foobar USING hash (foo); + CREATE UNIQUE INDEX foobar_unique_idx ON foobar(foo, bar); + -- local indexes + CREATE INDEX foobar_1_local_idx ON foobar_1(foo, fizz); + + CREATE table bar( + id VARCHAR(255) PRIMARY KEY, + foo INT DEFAULT 0 CHECK (foo > 0 AND foo > bar), + bar DOUBLE PRECISION NOT NULL DEFAULT 8.8, + fizz timestamptz DEFAULT CURRENT_TIMESTAMP, + buzz REAL NOT NULL CHECK (buzz IS NOT NULL) + ); + CREATE INDEX bar_normal_idx ON bar(bar); + CREATE INDEX bar_another_normal_id ON bar(bar, fizz); + CREATE UNIQUE INDEX bar_unique_idx on bar(fizz, buzz); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE fizz( + ); + + CREATE FUNCTION add(a integer, b integer) RETURNS integer + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN a + b; + + CREATE FUNCTION increment(i integer) RETURNS integer AS $$ + BEGIN + RETURN i + 1; + END; + $$ LANGUAGE plpgsql; + + CREATE FUNCTION function_with_dependencies(a integer, b integer) RETURNS integer + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN add(a, b) + increment(a); + + CREATE TABLE foobar( + id INT, + foo VARCHAR(255) DEFAULT 'some default' NOT NULL CHECK (LENGTH(foo) > 0), + bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + fizz BOOLEAN NOT NULL, + PRIMARY KEY (foo, id) + ) PARTITION BY LIST(foo); + + CREATE TABLE foobar_1 PARTITION of foobar( + fizz NOT NULL + ) FOR VALUES IN ('foobar_1_val_1', 'foobar_1_val_2'); + + -- partitioned indexes + CREATE INDEX foobar_normal_idx ON foobar(foo DESC, fizz); + CREATE INDEX foobar_hash_idx ON foobar USING hash (foo); + CREATE UNIQUE INDEX foobar_unique_idx ON foobar(foo, bar); + -- local indexes + CREATE INDEX foobar_1_local_idx ON foobar_1(foo, fizz); + + CREATE table bar( + id VARCHAR(255) PRIMARY KEY, + foo INT DEFAULT 0 CHECK (foo > 0 AND foo > bar), + bar DOUBLE PRECISION NOT NULL DEFAULT 8.8, + fizz timestamptz DEFAULT CURRENT_TIMESTAMP, + buzz REAL NOT NULL CHECK (buzz IS NOT NULL) + ); + CREATE INDEX bar_normal_idx ON bar(bar); + CREATE INDEX bar_another_normal_id ON bar(bar, fizz); + CREATE UNIQUE INDEX bar_unique_idx on bar(fizz, buzz); + `, + }, + }, + { + name: "Drop table, Add Table, Drop Funcs, Add Funcs, Drop Triggers, Add Triggers", + oldSchemaDDL: []string{ + ` + CREATE TABLE fizz( + ); + + CREATE FUNCTION add(a integer, b integer) RETURNS integer + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN a + b; + + CREATE FUNCTION increment(i integer) RETURNS integer AS $$ + BEGIN + RETURN i + 1; + END; + $$ LANGUAGE plpgsql; + + CREATE FUNCTION function_with_dependencies(a integer, b integer) RETURNS integer + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN add(a, b) + increment(a); + + CREATE FUNCTION "increment version"() RETURNS TRIGGER AS $$ + BEGIN + NEW.version = increment(OLD.version); + RETURN NEW; + END; + $$ language 'plpgsql'; + + CREATE TABLE foobar( + id INT PRIMARY KEY, + foo VARCHAR(255) DEFAULT 'some default' NOT NULL, + bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + fizz BOOLEAN NOT NULL + ); + CREATE INDEX foobar_normal_idx ON foobar USING hash (fizz); + CREATE UNIQUE INDEX foobar_unique_idx ON foobar(foo, bar DESC); + + CREATE TRIGGER "some trigger" + BEFORE UPDATE ON foobar + FOR EACH ROW + WHEN (OLD.* IS DISTINCT FROM NEW.*) + EXECUTE PROCEDURE "increment version"(); + + CREATE table bar( + id VARCHAR(255) PRIMARY KEY, + foo INT DEFAULT 0, + bar DOUBLE PRECISION NOT NULL DEFAULT 8.8, + fizz timestamptz DEFAULT CURRENT_TIMESTAMP, + buzz REAL NOT NULL + ); + ALTER TABLE bar ADD CONSTRAINT "FOO_CHECK" CHECK (foo < bar) NOT VALID; + CREATE INDEX bar_normal_idx ON bar(bar); + CREATE INDEX bar_another_normal_id ON bar(bar DESC, fizz DESC); + CREATE UNIQUE INDEX bar_unique_idx on bar(fizz, buzz); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE fizz( + ); + + CREATE FUNCTION "new add"(a integer, b integer) RETURNS integer + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN a + b + a; + + CREATE FUNCTION "new increment"(i integer) RETURNS integer AS $$ + BEGIN + RETURN i + 2; + END; + $$ LANGUAGE plpgsql; + + CREATE FUNCTION "new function with dependencies"(a integer, b integer) RETURNS integer + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN "new add"(a, b) + "new increment"(a); + + CREATE FUNCTION "increment version"() RETURNS TRIGGER AS $$ + BEGIN + NEW.version = "new increment"(OLD.version); + RETURN NEW; + END; + $$ language 'plpgsql'; + + CREATE TABLE "New_table"( + id INT PRIMARY KEY, + new_foo VARCHAR(255) DEFAULT '' NOT NULL CHECK ( new_foo IS NOT NULL), + new_bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + version INT NOT NULL DEFAULT 0 + ); + ALTER TABLE "New_table" ADD CONSTRAINT "new_bar_check" CHECK ( new_bar < CURRENT_TIMESTAMP - interval '1 month' ) NO INHERIT NOT VALID; + CREATE UNIQUE INDEX foobar_unique_idx ON "New_table"(new_foo, new_bar); + + CREATE TRIGGER "some trigger" + BEFORE UPDATE ON "New_table" + FOR EACH ROW + WHEN (OLD.* IS DISTINCT FROM NEW.*) + EXECUTE PROCEDURE "increment version"(); + + CREATE TABLE bar( + id VARCHAR(255) PRIMARY KEY, + foo INT, + bar DOUBLE PRECISION NOT NULL DEFAULT 8.8, + fizz timestamptz DEFAULT CURRENT_TIMESTAMP, + buzz REAL NOT NULL + ); + ALTER TABLE bar ADD CONSTRAINT "FOO_CHECK" CHECK ( foo < bar ); + CREATE INDEX bar_normal_idx ON bar(bar); + CREATE INDEX bar_another_normal_id ON bar(bar DESC, fizz DESC); + CREATE UNIQUE INDEX bar_unique_idx on bar(fizz, buzz); + + CREATE FUNCTION check_content() RETURNS TRIGGER AS $$ + BEGIN + IF LENGTH(NEW.id) == 0 THEN + RAISE EXCEPTION 'content is empty'; + END IF; + END; + $$ language 'plpgsql'; + + CREATE TRIGGER some_check_trigger + BEFORE UPDATE ON bar + FOR EACH ROW + EXECUTE FUNCTION check_content(); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeDeletesData, + diff.MigrationHazardTypeHasUntrackableDependencies, + }, + dataPackingExpectations: expectations{ + outputState: []string{ + ` + CREATE TABLE fizz( + ); + + CREATE FUNCTION "new add"(a integer, b integer) RETURNS integer + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN a + b + a; + + CREATE FUNCTION "new increment"(i integer) RETURNS integer AS $$ + BEGIN + RETURN i + 2; + END; + $$ LANGUAGE plpgsql; + + CREATE FUNCTION "new function with dependencies"(a integer, b integer) RETURNS integer + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN "new add"(a, b) + "new increment"(a); + + CREATE FUNCTION "increment version"() RETURNS TRIGGER AS $$ + BEGIN + NEW.version = "new increment"(OLD.version); + RETURN NEW; + END; + $$ language 'plpgsql'; + + CREATE TABLE "New_table"( + new_bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + id INT PRIMARY KEY, + version INT NOT NULL DEFAULT 0, + new_foo VARCHAR(255) DEFAULT '' NOT NULL CHECK (new_foo IS NOT NULL) + ); + ALTER TABLE "New_table" ADD CONSTRAINT "new_bar_check" CHECK ( new_bar < CURRENT_TIMESTAMP - interval '1 month' ) NO INHERIT NOT VALID; + CREATE UNIQUE INDEX foobar_unique_idx ON "New_table"(new_foo, new_bar); + + CREATE TRIGGER "some trigger" + BEFORE UPDATE ON "New_table" + FOR EACH ROW + WHEN (OLD.* IS DISTINCT FROM NEW.*) + EXECUTE PROCEDURE "increment version"(); + + CREATE TABLE bar( + id VARCHAR(255) PRIMARY KEY, + foo INT, + bar DOUBLE PRECISION NOT NULL DEFAULT 8.8, + fizz timestamptz DEFAULT CURRENT_TIMESTAMP, + buzz REAL NOT NULL + ); + ALTER TABLE bar ADD CONSTRAINT "FOO_CHECK" CHECK ( foo < bar ); + CREATE INDEX bar_normal_idx ON bar(bar); + CREATE INDEX bar_another_normal_id ON bar(bar DESC, fizz DESC); + CREATE UNIQUE INDEX bar_unique_idx on bar(fizz, buzz); + + CREATE FUNCTION check_content() RETURNS TRIGGER AS $$ + BEGIN + IF LENGTH(NEW.id) == 0 THEN + RAISE EXCEPTION 'content is empty'; + END IF; + END; + $$ language 'plpgsql'; + + CREATE TRIGGER some_check_trigger + BEFORE UPDATE ON bar + FOR EACH ROW + EXECUTE FUNCTION check_content(); + `, + }, + }, + }, + { + name: "Drop partitioned table, Add partitioned table with local keys", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255) DEFAULT 'some default' NOT NULL CHECK (LENGTH(foo) > 0), + bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + fizz BOOLEAN NOT NULL, + PRIMARY KEY (foo, id) + ) PARTITION BY LIST(foo); + + CREATE TABLE foobar_1 PARTITION of foobar( + fizz NOT NULL + ) FOR VALUES IN ('foobar_1_val_1', 'foobar_1_val_2'); + + -- partitioned indexes + CREATE INDEX foobar_normal_idx ON foobar(foo, fizz); + CREATE UNIQUE INDEX foobar_unique_idx ON foobar(foo, bar); + -- local indexes + CREATE INDEX foobar_1_local_idx ON foobar_1(foo, fizz); + + CREATE table bar( + id VARCHAR(255) PRIMARY KEY, + foo INT DEFAULT 0 CHECK (foo > 0 AND foo > bar), + bar DOUBLE PRECISION NOT NULL DEFAULT 8.8, + fizz timestamptz DEFAULT CURRENT_TIMESTAMP, + buzz REAL NOT NULL CHECK (buzz IS NOT NULL) + ); + CREATE INDEX bar_normal_idx ON bar(bar); + CREATE INDEX bar_another_normal_id ON bar(bar, fizz); + CREATE UNIQUE INDEX bar_unique_idx on bar(fizz, buzz); + + CREATE TABLE fizz( + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE new_foobar( + id INT, + foo VARCHAR(255) DEFAULT 'some default' NOT NULL CHECK (LENGTH(foo) > 0), + bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + fizz BOOLEAN NOT NULL + ) PARTITION BY LIST(foo); + + CREATE TABLE foobar_1 PARTITION of new_foobar( + fizz NOT NULL, + PRIMARY KEY (foo, bar) + ) FOR VALUES IN ('foobar_1_val_1', 'foobar_1_val_2'); + + -- local indexes + CREATE INDEX foobar_1_local_idx ON foobar_1(foo, fizz); + -- partitioned indexes + CREATE INDEX foobar_normal_idx ON new_foobar(foo, fizz); + CREATE UNIQUE INDEX foobar_unique_idx ON new_foobar(foo, bar); + + CREATE table bar( + id VARCHAR(255) PRIMARY KEY, + foo INT DEFAULT 0 CHECK (foo > 0 AND foo > bar), + bar DOUBLE PRECISION NOT NULL DEFAULT 8.8, + fizz timestamptz DEFAULT CURRENT_TIMESTAMP, + buzz REAL NOT NULL CHECK (buzz IS NOT NULL) + ); + CREATE INDEX bar_normal_idx ON bar(bar); + CREATE INDEX bar_another_normal_id ON bar(bar, fizz); + CREATE UNIQUE INDEX bar_unique_idx on bar(fizz, buzz); + + CREATE TABLE fizz( + ); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeDeletesData, + }, + dataPackingExpectations: expectations{ + outputState: []string{ + ` + CREATE TABLE new_foobar( + bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + id INT, + fizz BOOLEAN NOT NULL, + foo VARCHAR(255) DEFAULT 'some default' NOT NULL CHECK (LENGTH(foo) > 0) + ) PARTITION BY LIST(foo); + + CREATE TABLE foobar_1 PARTITION of new_foobar( + fizz NOT NULL, + PRIMARY KEY (foo, bar) + ) FOR VALUES IN ('foobar_1_val_1', 'foobar_1_val_2'); + + -- local indexes + CREATE INDEX foobar_1_local_idx ON foobar_1(foo, fizz); + -- partitioned indexes + CREATE INDEX foobar_normal_idx ON new_foobar(foo, fizz); + CREATE UNIQUE INDEX foobar_unique_idx ON new_foobar(foo, bar); + + CREATE table bar( + id VARCHAR(255) PRIMARY KEY, + foo INT DEFAULT 0 CHECK (foo > 0 AND foo > bar), + bar DOUBLE PRECISION NOT NULL DEFAULT 8.8, + fizz timestamptz DEFAULT CURRENT_TIMESTAMP, + buzz REAL NOT NULL CHECK (buzz IS NOT NULL) + ); + CREATE INDEX bar_normal_idx ON bar(bar); + CREATE INDEX bar_another_normal_id ON bar(bar, fizz); + CREATE UNIQUE INDEX bar_unique_idx on bar(fizz, buzz); + + CREATE TABLE fizz( + ); + `, + }, + }, + }, +} + +func (suite *acceptanceTestSuite) TestSchemaAcceptanceTestCases() { + suite.runTestCases(schemaAcceptanceTests) +} diff --git a/internal/migration_acceptance_tests/table_cases_test.go b/internal/migration_acceptance_tests/table_cases_test.go new file mode 100644 index 0000000..0c5c9a1 --- /dev/null +++ b/internal/migration_acceptance_tests/table_cases_test.go @@ -0,0 +1,273 @@ +package migration_acceptance_tests + +import ( + "github.com/stripe/pg-schema-diff/pkg/diff" +) + +var tableAcceptanceTestCases = []acceptanceTestCase{ + { + name: "No-op", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY CHECK (id > 0), CHECK (id < buzz), + foo VARCHAR(255) COLLATE "POSIX" DEFAULT '' NOT NULL, + bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + fizz BOOLEAN NOT NULL, + buzz REAL CHECK (buzz IS NOT NULL) + ); + CREATE INDEX normal_idx ON foobar(fizz); + CREATE UNIQUE INDEX unique_idx ON foobar(foo, bar); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY CHECK (id > 0), + foo VARCHAR(255) COLLATE "POSIX" DEFAULT '' NOT NULL, + bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + fizz BOOLEAN NOT NULL, + buzz REAL CHECK (buzz IS NOT NULL) + ); + CREATE INDEX normal_idx ON foobar(fizz); + CREATE UNIQUE INDEX unique_idx ON foobar(foo, bar); + `, + }, + }, + { + name: "Create table", + oldSchemaDDL: nil, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY CHECK (id > 0), CHECK (id < buzz), + foo VARCHAR(255) COLLATE "POSIX" DEFAULT '' NOT NULL, + bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + fizz BOOLEAN NOT NULL, + buzz REAL CHECK (buzz IS NOT NULL) + ); + CREATE INDEX normal_idx ON foobar(fizz); + CREATE UNIQUE INDEX unique_idx ON foobar(foo, bar); + `, + }, + dataPackingExpectations: expectations{ + outputState: []string{` + CREATE TABLE foobar( + bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + id INT PRIMARY KEY CHECK (id > 0), CHECK (id < buzz), + buzz REAL CHECK (buzz IS NOT NULL), + fizz BOOLEAN NOT NULL, + foo VARCHAR(255) COLLATE "POSIX" DEFAULT '' NOT NULL + ); + CREATE INDEX normal_idx ON foobar(fizz); + CREATE UNIQUE INDEX unique_idx ON foobar(foo, bar); + `}, + }, + }, + { + name: "Create table with quoted names", + oldSchemaDDL: nil, + newSchemaDDL: []string{ + ` + CREATE TABLE "Foobar"( + id INT PRIMARY KEY, + "Foo" VARCHAR(255) COLLATE "POSIX" DEFAULT '' NOT NULL CHECK (LENGTH("Foo") > 0), + bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + fizz BOOLEAN NOT NULL + ); + CREATE INDEX normal_idx ON "Foobar" USING hash (fizz); + CREATE UNIQUE INDEX unique_idx ON "Foobar"("Foo" DESC, bar); + `, + }, + dataPackingExpectations: expectations{ + outputState: []string{ + ` + CREATE TABLE "Foobar"( + bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + id INT PRIMARY KEY, + fizz BOOLEAN NOT NULL, + "Foo" VARCHAR(255) COLLATE "POSIX" DEFAULT '' NOT NULL CHECK (LENGTH("Foo") > 0) + ); + CREATE INDEX normal_idx ON "Foobar" USING hash (fizz); + CREATE UNIQUE INDEX unique_idx ON "Foobar"("Foo" DESC, bar); + `, + }, + }, + }, + { + name: "Drop table", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY CHECK (id > 0), CHECK (id < buzz), + foo VARCHAR(255) COLLATE "C" DEFAULT '' NOT NULL, + bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + fizz BOOLEAN NOT NULL, + buzz REAL CHECK (buzz IS NOT NULL) + ); + CREATE INDEX normal_idx ON foobar(fizz); + CREATE UNIQUE INDEX unique_idx ON foobar(foo, bar); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeDeletesData, + }, + newSchemaDDL: nil, + }, + { + name: "Drop a table with quoted names", + oldSchemaDDL: []string{ + ` + CREATE TABLE "Foobar"( + id INT PRIMARY KEY, + "Foo" VARCHAR(255) COLLATE "C" DEFAULT '' NOT NULL CHECK (LENGTH("Foo") > 0), + bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + fizz BOOLEAN NOT NULL + ); + CREATE INDEX normal_idx ON "Foobar"(fizz); + CREATE UNIQUE INDEX unique_idx ON "Foobar"("Foo", "bar"); + `, + }, + newSchemaDDL: nil, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeDeletesData, + }, + }, + { + name: "Alter table: New primary key, change column types, delete unique index, new index, validate check constraint", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY CHECK (id > 0), CHECK (id < buzz), + foo VARCHAR(255) COLLATE "POSIX" DEFAULT '' NOT NULL, + bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + fizz BOOLEAN NOT NULL, + buzz REAL, + fizzbuzz TEXT + ); + ALTER TABLE foobar ADD CONSTRAINT buzz_check CHECK (buzz IS NOT NULL) NOT VALID; + CREATE INDEX normal_idx ON foobar(fizz); + CREATE UNIQUE INDEX unique_idx ON foobar(foo, bar); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT CHECK (id > 0), + foo CHAR COLLATE "C" DEFAULT '5' NOT NULL PRIMARY KEY, + bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL , + fizz BOOLEAN NOT NULL, + buzz REAL, + fizzbuzz TEXT COLLATE "POSIX" + ); + ALTER TABLE foobar ADD CONSTRAINT buzz_check CHECK (buzz IS NOT NULL); + CREATE INDEX normal_idx ON foobar(fizz); + CREATE INDEX other_idx ON foobar(bar); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeIndexDropped, + diff.MigrationHazardTypeIndexBuild, + diff.MigrationHazardTypeImpactsDatabasePerformance, + }, + }, + { + name: "Alter table: New column, new primary key, alter column to nullable, drop column, drop index, drop check constraints", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY CHECK (id > 0), CHECK (id < buzz), + foo VARCHAR(255) COLLATE "POSIX" DEFAULT '' NOT NULL, + bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + fizz BOOLEAN NOT NULL, + buzz REAL CHECK (buzz IS NOT NULL) + ); + CREATE INDEX normal_idx ON foobar(fizz); + CREATE UNIQUE INDEX unique_idx ON foobar(foo DESC, bar); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo CHAR DEFAULT '5', + bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL, + new_fizz DECIMAL(65, 10) DEFAULT 5.25 NOT NULL PRIMARY KEY + ); + CREATE INDEX other_idx ON foobar(bar); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeImpactsDatabasePerformance, + diff.MigrationHazardTypeDeletesData, + diff.MigrationHazardTypeIndexDropped, + diff.MigrationHazardTypeIndexBuild, + }, + }, + { + name: "Alter table: effectively drop by changing everything", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY CHECK (id > 0), CHECK (id < buzz), + foo VARCHAR(255) COLLATE "POSIX" DEFAULT '' NOT NULL, + bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + fizz BOOLEAN NOT NULL, + buzz REAL CHECK (buzz IS NOT NULL) + ); + CREATE INDEX normal_idx ON foobar USING hash (fizz); + CREATE UNIQUE INDEX unique_idx ON foobar(foo DESC, bar); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + new_id INT PRIMARY KEY CHECK (new_id > 0), CHECK (new_id < new_buzz), + new_foo VARCHAR(255) COLLATE "POSIX" DEFAULT '' NOT NULL, + new_bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + new_fizz BOOLEAN NOT NULL, + new_buzz REAL CHECK (new_buzz IS NOT NULL) + ); + CREATE INDEX normal_idx ON foobar USING hash (new_fizz); + CREATE UNIQUE INDEX unique_idx ON foobar(new_foo DESC, new_bar); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeDeletesData, + diff.MigrationHazardTypeIndexDropped, + diff.MigrationHazardTypeIndexBuild, + }, + }, + { + name: "Alter table: translate BIGINT type to TIMESTAMP, set to not null, set default", + oldSchemaDDL: []string{ + ` + CREATE TABLE alexrhee_testing( + id INT PRIMARY KEY, + obj_attr__c_time BIGINT, + obj_attr__m_time BIGINT + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE alexrhee_testing( + id INT PRIMARY KEY, + obj_attr__c_time TIMESTAMP NOT NULL, + obj_attr__m_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL + ); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAcquiresAccessExclusiveLock, + diff.MigrationHazardTypeImpactsDatabasePerformance, + }, + }, +} + +func (suite *acceptanceTestSuite) TestTableAcceptanceTestCases() { + suite.runTestCases(tableAcceptanceTestCases) +} diff --git a/internal/migration_acceptance_tests/trigger_cases_test.go b/internal/migration_acceptance_tests/trigger_cases_test.go new file mode 100644 index 0000000..fd41da4 --- /dev/null +++ b/internal/migration_acceptance_tests/trigger_cases_test.go @@ -0,0 +1,702 @@ +package migration_acceptance_tests + +import ( + "github.com/stripe/pg-schema-diff/pkg/diff" +) + +var triggerAcceptanceTestCases = []acceptanceTestCase{ + { + name: "No-op", + oldSchemaDDL: []string{ + ` + CREATE TABLE foo ( + id INTEGER PRIMARY KEY, + author TEXT, + content TEXT NOT NULL DEFAULT '', + version INT NOT NULL DEFAULT 0 + ); + + CREATE FUNCTION increment_version() RETURNS TRIGGER AS $$ + BEGIN + NEW.version = OLD.version + 1; + RETURN NEW; + END; + $$ language 'plpgsql'; + + CREATE TRIGGER some_update_trigger + BEFORE UPDATE ON foo + FOR EACH ROW + WHEN (OLD.* IS DISTINCT FROM NEW.*) + EXECUTE PROCEDURE increment_version(); + + CREATE FUNCTION check_content() RETURNS TRIGGER AS $$ + BEGIN + IF LENGTH(NEW.content) == 0 THEN + RAISE EXCEPTION 'content is empty'; + END IF; + END; + $$ language 'plpgsql'; + + CREATE TRIGGER some_check_trigger + BEFORE UPDATE ON foo + FOR EACH ROW + EXECUTE FUNCTION check_content(); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foo ( + id INTEGER PRIMARY KEY, + author TEXT, + content TEXT NOT NULL DEFAULT '', + version INT NOT NULL DEFAULT 0 + ); + + CREATE FUNCTION increment_version() RETURNS TRIGGER AS $$ + BEGIN + NEW.version = OLD.version + 1; + RETURN NEW; + END; + $$ language 'plpgsql'; + + CREATE TRIGGER some_update_trigger + BEFORE UPDATE ON foo + FOR EACH ROW + WHEN (OLD.* IS DISTINCT FROM NEW.*) + EXECUTE PROCEDURE increment_version(); + + CREATE FUNCTION check_content() RETURNS TRIGGER AS $$ + BEGIN + IF LENGTH(NEW.content) == 0 THEN + RAISE EXCEPTION 'content is empty'; + END IF; + END; + $$ language 'plpgsql'; + + CREATE TRIGGER some_check_trigger + BEFORE UPDATE ON foo + FOR EACH ROW + EXECUTE FUNCTION check_content(); + `, + }, + }, + { + name: "Create trigger with quoted name", + oldSchemaDDL: []string{ + ` + CREATE TABLE "some foo" ( + id INTEGER PRIMARY KEY, + author TEXT, + content TEXT NOT NULL DEFAULT '', + version INT NOT NULL DEFAULT 0 + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE "some foo" ( + id INTEGER PRIMARY KEY, + author TEXT, + content TEXT NOT NULL DEFAULT '', + version INT NOT NULL DEFAULT 0 + ); + + CREATE FUNCTION "increment version"() RETURNS TRIGGER AS $$ + BEGIN + NEW.version = OLD.version + 1; + RETURN NEW; + END; + $$ language 'plpgsql'; + + CREATE TRIGGER "some trigger" + BEFORE UPDATE ON "some foo" + FOR EACH ROW + WHEN (OLD.* IS DISTINCT FROM NEW.*) + EXECUTE PROCEDURE "increment version"(); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{diff.MigrationHazardTypeHasUntrackableDependencies}, + }, + { + name: "Create triggers with quoted names on partitioned table and partition", + oldSchemaDDL: []string{ + ` + CREATE TABLE "some foo" ( + id INTEGER, + author TEXT, + content TEXT NOT NULL DEFAULT '', + version INT NOT NULL DEFAULT 0 + ) PARTITION BY LIST (content); + + CREATE TABLE "foobar 1" PARTITION OF "some foo" FOR VALUES IN ('foo_2'); + + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE "some foo" ( + id INTEGER, + author TEXT, + content TEXT NOT NULL DEFAULT '', + version INT NOT NULL DEFAULT 0 + ) PARTITION BY LIST (content); + + CREATE TABLE "foobar 1" PARTITION OF "some foo" FOR VALUES IN ('foo_2'); + + CREATE FUNCTION "increment version"() RETURNS TRIGGER AS $$ + BEGIN + NEW.version = OLD.version + 1; + RETURN NEW; + END; + $$ language 'plpgsql'; + + CREATE TRIGGER "some trigger" + BEFORE UPDATE ON "some foo" + FOR EACH ROW + WHEN (OLD.* IS DISTINCT FROM NEW.*) + EXECUTE PROCEDURE "increment version"(); + + CREATE TRIGGER "some partition trigger" + BEFORE UPDATE ON "foobar 1" + FOR EACH ROW + WHEN (OLD.* IS DISTINCT FROM NEW.*) + EXECUTE PROCEDURE "increment version"(); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{diff.MigrationHazardTypeHasUntrackableDependencies}, + }, + { + name: "Create two triggers depending on the same function", + oldSchemaDDL: []string{ + ` + CREATE TABLE "some foo" ( + id INTEGER PRIMARY KEY, + author TEXT, + content TEXT NOT NULL DEFAULT '', + version INT NOT NULL DEFAULT 0 + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE "some foo" ( + id INTEGER PRIMARY KEY, + author TEXT, + content TEXT NOT NULL DEFAULT '', + version INT NOT NULL DEFAULT 0 + ); + + CREATE FUNCTION "increment version"() RETURNS TRIGGER AS $$ + BEGIN + NEW.version = OLD.version + 1; + RETURN NEW; + END; + $$ language 'plpgsql'; + + CREATE TRIGGER "some trigger" + BEFORE UPDATE ON "some foo" + FOR EACH ROW + WHEN (OLD.* IS DISTINCT FROM NEW.*) + EXECUTE PROCEDURE "increment version"(); + + CREATE TRIGGER "some other trigger" + BEFORE UPDATE ON "some foo" + FOR EACH ROW + WHEN (OLD.* IS DISTINCT FROM NEW.*) + EXECUTE PROCEDURE "increment version"(); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{diff.MigrationHazardTypeHasUntrackableDependencies}, + }, + { + name: "Create two triggers with the same name", + oldSchemaDDL: []string{ + ` + CREATE TABLE "some foo" ( + id INTEGER PRIMARY KEY, + author TEXT, + content TEXT NOT NULL DEFAULT '', + version INT NOT NULL DEFAULT 0 + ); + + CREATE TABLE "some other foo" ( + id INTEGER PRIMARY KEY, + author TEXT, + content TEXT NOT NULL DEFAULT '', + version INT NOT NULL DEFAULT 0 + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE "some foo" ( + id INTEGER PRIMARY KEY, + author TEXT, + content TEXT NOT NULL DEFAULT '', + version INT NOT NULL DEFAULT 0 + ); + + CREATE TABLE "some other foo" ( + id INTEGER PRIMARY KEY, + author TEXT, + content TEXT NOT NULL DEFAULT '', + version INT NOT NULL DEFAULT 0 + ); + + CREATE FUNCTION "increment version"() RETURNS TRIGGER AS $$ + BEGIN + NEW.version = OLD.version + 1; + RETURN NEW; + END; + $$ language 'plpgsql'; + + CREATE TRIGGER "some trigger" + BEFORE UPDATE ON "some foo" + FOR EACH ROW + WHEN (OLD.* IS DISTINCT FROM NEW.*) + EXECUTE PROCEDURE "increment version"(); + + CREATE TRIGGER "some trigger" + BEFORE UPDATE ON "some other foo" + FOR EACH ROW + WHEN (OLD.* IS DISTINCT FROM NEW.*) + EXECUTE PROCEDURE "increment version"(); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{diff.MigrationHazardTypeHasUntrackableDependencies}, + }, + { + name: "Drop trigger with quoted names", + oldSchemaDDL: []string{ + ` + CREATE TABLE "some foo" ( + id INTEGER PRIMARY KEY, + author TEXT, + content TEXT NOT NULL DEFAULT '', + version INT NOT NULL DEFAULT 0 + ); + + CREATE FUNCTION "increment version"() RETURNS TRIGGER AS $$ + BEGIN + NEW.version = OLD.version + 1; + RETURN NEW; + END; + $$ language 'plpgsql'; + + CREATE TRIGGER "some trigger" + BEFORE UPDATE ON "some foo" + FOR EACH ROW + WHEN (OLD.* IS DISTINCT FROM NEW.*) + EXECUTE PROCEDURE "increment version"(); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE "some foo" ( + id INTEGER PRIMARY KEY, + author TEXT, + content TEXT NOT NULL DEFAULT '', + version INT NOT NULL DEFAULT 0 + ); + + CREATE FUNCTION "increment version"() RETURNS TRIGGER AS $$ + BEGIN + NEW.version = OLD.version + 1; + RETURN NEW; + END; + $$ language 'plpgsql'; + `, + }, + }, + { + name: "Drop triggers with quoted names on partitioned table and partition", + oldSchemaDDL: []string{ + ` + CREATE TABLE "some foo" ( + id INTEGER, + author TEXT, + content TEXT NOT NULL DEFAULT '', + version INT NOT NULL DEFAULT 0 + ) PARTITION BY LIST (content); + + CREATE TABLE "foobar 1" PARTITION OF "some foo" FOR VALUES IN ('foo_2'); + + CREATE FUNCTION "increment version"() RETURNS TRIGGER AS $$ + BEGIN + NEW.version = OLD.version + 1; + RETURN NEW; + END; + $$ language 'plpgsql'; + + CREATE TRIGGER "some trigger" + BEFORE UPDATE ON "some foo" + FOR EACH ROW + WHEN (OLD.* IS DISTINCT FROM NEW.*) + EXECUTE PROCEDURE "increment version"(); + + CREATE TRIGGER "some partition trigger" + BEFORE UPDATE ON "foobar 1" + FOR EACH ROW + WHEN (OLD.* IS DISTINCT FROM NEW.*) + EXECUTE PROCEDURE "increment version"(); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE "some foo" ( + id INTEGER, + author TEXT, + content TEXT NOT NULL DEFAULT '', + version INT NOT NULL DEFAULT 0 + ) PARTITION BY LIST (content); + + CREATE TABLE "foobar 1" PARTITION OF "some foo" FOR VALUES IN ('foo_2'); + + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{diff.MigrationHazardTypeHasUntrackableDependencies}, + }, + { + name: "Drop two triggers with the same name", + oldSchemaDDL: []string{ + ` + CREATE TABLE "some foo" ( + id INTEGER PRIMARY KEY, + author TEXT, + content TEXT NOT NULL DEFAULT '', + version INT NOT NULL DEFAULT 0 + ); + + CREATE TABLE "some other foo" ( + id INTEGER PRIMARY KEY, + author TEXT, + content TEXT NOT NULL DEFAULT '', + version INT NOT NULL DEFAULT 0 + ); + + CREATE FUNCTION "increment version"() RETURNS TRIGGER AS $$ + BEGIN + NEW.version = OLD.version + 1; + RETURN NEW; + END; + $$ language 'plpgsql'; + + CREATE TRIGGER "some trigger" + BEFORE UPDATE ON "some foo" + FOR EACH ROW + WHEN (OLD.* IS DISTINCT FROM NEW.*) + EXECUTE PROCEDURE "increment version"(); + + CREATE TRIGGER "some trigger" + BEFORE UPDATE ON "some other foo" + FOR EACH ROW + WHEN (OLD.* IS DISTINCT FROM NEW.*) + EXECUTE PROCEDURE "increment version"(); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE "some foo" ( + id INTEGER PRIMARY KEY, + author TEXT, + content TEXT NOT NULL DEFAULT '', + version INT NOT NULL DEFAULT 0 + ); + + CREATE TABLE "some other foo" ( + id INTEGER PRIMARY KEY, + author TEXT, + content TEXT NOT NULL DEFAULT '', + version INT NOT NULL DEFAULT 0 + ); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{diff.MigrationHazardTypeHasUntrackableDependencies}, + }, + { + name: "Alter trigger when clause", + oldSchemaDDL: []string{ + ` + CREATE TABLE "some foo" ( + id INTEGER PRIMARY KEY, + author TEXT, + content TEXT NOT NULL DEFAULT '', + version INT NOT NULL DEFAULT 0 + ); + + CREATE FUNCTION "increment version"() RETURNS TRIGGER AS $$ + BEGIN + NEW.version = OLD.version + 1; + RETURN NEW; + END; + $$ language 'plpgsql'; + + CREATE TRIGGER "some trigger" + BEFORE UPDATE ON "some foo" + FOR EACH ROW + WHEN (OLD.* IS DISTINCT FROM NEW.*) + EXECUTE PROCEDURE "increment version"(); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE "some foo" ( + id INTEGER PRIMARY KEY, + author TEXT, + content TEXT NOT NULL DEFAULT '', + version INT NOT NULL DEFAULT 0 + ); + + CREATE FUNCTION "increment version"() RETURNS TRIGGER AS $$ + BEGIN + NEW.version = OLD.version + 1; + RETURN NEW; + END; + $$ language 'plpgsql'; + + CREATE TRIGGER "some trigger" + BEFORE UPDATE ON "some foo" + FOR EACH ROW + WHEN (NEW.author != 'fizz') + EXECUTE PROCEDURE "increment version"(); + `, + }, + }, + { + name: "Alter trigger table", + oldSchemaDDL: []string{ + ` + CREATE TABLE "some foo" ( + id INTEGER PRIMARY KEY, + author TEXT, + content TEXT NOT NULL DEFAULT '', + version INT NOT NULL DEFAULT 0 + ); + CREATE TABLE "some other foo" ( + id INTEGER PRIMARY KEY, + author TEXT, + content TEXT NOT NULL DEFAULT '', + version INT NOT NULL DEFAULT 0 + ); + + CREATE FUNCTION "increment version"() RETURNS TRIGGER AS $$ + BEGIN + NEW.version = OLD.version + 1; + RETURN NEW; + END; + $$ language 'plpgsql'; + + CREATE TRIGGER "some trigger" + BEFORE UPDATE ON "some foo" + FOR EACH ROW + WHEN (OLD.* IS DISTINCT FROM NEW.*) + EXECUTE PROCEDURE "increment version"(); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE "some foo" ( + id INTEGER PRIMARY KEY, + author TEXT, + content TEXT NOT NULL DEFAULT '', + version INT NOT NULL DEFAULT 0 + ); + CREATE TABLE "some other foo" ( + id INTEGER PRIMARY KEY, + author TEXT, + content TEXT NOT NULL DEFAULT '', + version INT NOT NULL DEFAULT 0 + ); + + CREATE FUNCTION "increment version"() RETURNS TRIGGER AS $$ + BEGIN + NEW.version = OLD.version + 1; + RETURN NEW; + END; + $$ language 'plpgsql'; + + CREATE TRIGGER "some trigger" + BEFORE UPDATE ON "some other foo" + FOR EACH ROW + WHEN (OLD.* IS DISTINCT FROM NEW.*) + EXECUTE PROCEDURE "increment version"(); + `, + }, + }, + { + name: "Change trigger function and keep old function", + oldSchemaDDL: []string{ + ` + CREATE TABLE "some foo" ( + id INTEGER PRIMARY KEY, + author TEXT, + content TEXT NOT NULL DEFAULT '', + version INT NOT NULL DEFAULT 0 + ); + + CREATE FUNCTION "increment version"() RETURNS TRIGGER AS $$ + BEGIN + NEW.version = OLD.version + 1; + RETURN NEW; + END; + $$ language 'plpgsql'; + + CREATE TRIGGER "some trigger" + BEFORE UPDATE ON "some foo" + FOR EACH ROW + WHEN (OLD.* IS DISTINCT FROM NEW.*) + EXECUTE PROCEDURE "increment version"(); + `}, + newSchemaDDL: []string{ + ` + CREATE TABLE "some foo" ( + id INTEGER PRIMARY KEY, + author TEXT, + content TEXT NOT NULL DEFAULT '', + version INT NOT NULL DEFAULT 0 + ); + + CREATE FUNCTION "increment version"() RETURNS TRIGGER AS $$ + BEGIN + NEW.version = OLD.version + 1; + RETURN NEW; + END; + $$ language 'plpgsql'; + + + CREATE FUNCTION "decrement version"() RETURNS TRIGGER AS $$ + BEGIN + NEW.version = OLD.version + 1; + RETURN NEW; + END; + $$ language 'plpgsql'; + + CREATE TRIGGER "some trigger" + BEFORE UPDATE ON "some foo" + FOR EACH ROW + WHEN (OLD.* IS DISTINCT FROM NEW.*) + EXECUTE PROCEDURE "decrement version"(); + `}, + expectedHazardTypes: []diff.MigrationHazardType{diff.MigrationHazardTypeHasUntrackableDependencies}, + }, + { + name: "Change trigger function and drop old function", + oldSchemaDDL: []string{ + ` + CREATE TABLE "some foo" ( + id INTEGER PRIMARY KEY, + author TEXT, + content TEXT NOT NULL DEFAULT '', + version INT NOT NULL DEFAULT 0 + ); + + CREATE FUNCTION "increment version"() RETURNS TRIGGER AS $$ + BEGIN + NEW.version = OLD.version + 1; + RETURN NEW; + END; + $$ language 'plpgsql'; + + CREATE TRIGGER "some trigger" + BEFORE UPDATE ON "some foo" + FOR EACH ROW + WHEN (OLD.* IS DISTINCT FROM NEW.*) + EXECUTE PROCEDURE "increment version"(); + `}, + newSchemaDDL: []string{ + ` + CREATE TABLE "some foo" ( + id INTEGER PRIMARY KEY, + author TEXT, + content TEXT NOT NULL DEFAULT '', + version INT NOT NULL DEFAULT 0 + ); + + CREATE FUNCTION "decrement version"() RETURNS TRIGGER AS $$ + BEGIN + NEW.version = OLD.version + 1; + RETURN NEW; + END; + $$ language 'plpgsql'; + + CREATE TRIGGER "some trigger" + BEFORE UPDATE ON "some foo" + FOR EACH ROW + WHEN (OLD.* IS DISTINCT FROM NEW.*) + EXECUTE PROCEDURE "decrement version"(); + `}, + expectedHazardTypes: []diff.MigrationHazardType{diff.MigrationHazardTypeHasUntrackableDependencies}, + }, + { + name: "Trigger on re-created table is re-created", + oldSchemaDDL: []string{ + ` + CREATE TABLE "some foobar" ( + id INTEGER, + version INT NOT NULL DEFAULT 0, + author TEXT, + content TEXT NOT NULL DEFAULT '' + ) PARTITION BY LIST (content); + + CREATE TABLE "foobar 1" PARTITION OF "some foobar" FOR VALUES IN ('foo_2'); + + CREATE FUNCTION "increment version"() RETURNS TRIGGER AS $$ + BEGIN + NEW.version = OLD.version + 1; + RETURN NEW; + END; + $$ language 'plpgsql'; + + CREATE TRIGGER "some trigger" + BEFORE UPDATE ON "some foobar" + FOR EACH ROW + WHEN (OLD.* IS DISTINCT FROM NEW.*) + EXECUTE PROCEDURE "increment version"(); + + CREATE TRIGGER "some partition trigger" + BEFORE UPDATE ON "foobar 1" + FOR EACH ROW + WHEN (OLD.* IS DISTINCT FROM NEW.*) + EXECUTE PROCEDURE "increment version"(); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE "some other foobar" ( + id INTEGER, + version INT NOT NULL DEFAULT 0, + author TEXT, + content TEXT NOT NULL DEFAULT '' + ) PARTITION BY LIST (content); + + CREATE TABLE "foobar 1" PARTITION OF "some other foobar" FOR VALUES IN ('foo_2'); + + CREATE FUNCTION "increment version"() RETURNS TRIGGER AS $$ + BEGIN + NEW.version = OLD.version + 1; + RETURN NEW; + END; + $$ language 'plpgsql'; + + CREATE TRIGGER "some trigger" + BEFORE UPDATE ON "some other foobar" + FOR EACH ROW + WHEN (OLD.* IS DISTINCT FROM NEW.*) + EXECUTE PROCEDURE "increment version"(); + + CREATE TRIGGER "some partition trigger" + BEFORE UPDATE ON "foobar 1" + FOR EACH ROW + WHEN (OLD.* IS DISTINCT FROM NEW.*) + EXECUTE PROCEDURE "increment version"(); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeDeletesData, + }, + }, +} + +func (suite *acceptanceTestSuite) TestTriggerAcceptanceTestCases() { + suite.runTestCases(triggerAcceptanceTestCases) +} diff --git a/internal/pgdump/dump.go b/internal/pgdump/dump.go new file mode 100644 index 0000000..d63dac8 --- /dev/null +++ b/internal/pgdump/dump.go @@ -0,0 +1,55 @@ +package pgdump + +import ( + "errors" + "fmt" + "os/exec" + + "github.com/stripe/pg-schema-diff/internal/pgengine" +) + +// Parameter represents a parameter to be pg_dump. Don't use a type alias for a string slice +// because all parameters for pgdump should be explicitly added here +type Parameter struct { + values []string `explicit:"always"` +} + +func WithExcludeSchema(pattern string) Parameter { + return Parameter{ + values: []string{"--exclude-schema", pattern}, + } +} + +func WithSchemaOnly() Parameter { + return Parameter{ + values: []string{"--schema-only"}, + } +} + +// GetDump gets the pg_dump of the inputted database. +// It is only intended to be used for testing. You cannot securely pass passwords with this implementation, so it will +// only accept databases created for unit tests (spun up with the pgengine package) +// "pgdump" must be on the system's PATH +func GetDump(db *pgengine.DB, additionalParams ...Parameter) (string, error) { + pgDumpBinaryPath, err := exec.LookPath("pg_dump") + if err != nil { + return "", errors.New("pg_dump executable not found in path") + } + return GetDumpUsingBinary(pgDumpBinaryPath, db, additionalParams...) +} + +func GetDumpUsingBinary(pgDumpBinaryPath string, db *pgengine.DB, additionalParams ...Parameter) (string, error) { + params := []string{ + db.GetDSN(), + } + for _, param := range additionalParams { + params = append(params, param.values...) + } + + output, err := exec.Command(pgDumpBinaryPath, params...).CombinedOutput() + if err != nil { + return "", fmt.Errorf("running pg dump \noutput=%s\n: %w", output, err) + } + + return string(output), nil +} diff --git a/internal/pgdump/dump_test.go b/internal/pgdump/dump_test.go new file mode 100644 index 0000000..ff2f130 --- /dev/null +++ b/internal/pgdump/dump_test.go @@ -0,0 +1,53 @@ +package pgdump_test + +import ( + "context" + "database/sql" + "testing" + + "github.com/stretchr/testify/require" + "github.com/stripe/pg-schema-diff/internal/pgdump" + "github.com/stripe/pg-schema-diff/internal/pgengine" +) + +func TestGetDump(t *testing.T) { + pgEngine, err := pgengine.StartEngine() + require.NoError(t, err) + defer pgEngine.Close() + + db, err := pgEngine.CreateDatabase() + require.NoError(t, err) + defer db.DropDB() + + connPool, err := sql.Open("pgx", db.GetDSN()) + require.NoError(t, err) + defer connPool.Close() + + _, err = connPool.ExecContext(context.Background(), ` + CREATE TABLE foobar(foobar_id text); + + INSERT INTO foobar VALUES ('some-id'); + + CREATE SCHEMA test; + CREATE TABLE test.bar(bar_id text); + `) + require.NoError(t, err) + + dump, err := pgdump.GetDump(db) + require.NoError(t, err) + require.Contains(t, dump, "public.foobar") + require.Contains(t, dump, "test.bar") + require.Contains(t, dump, "some-id") + + onlySchemasDump, err := pgdump.GetDump(db, pgdump.WithSchemaOnly()) + require.NoError(t, err) + require.Contains(t, onlySchemasDump, "public.foobar") + require.Contains(t, onlySchemasDump, "test.bar") + require.NotContains(t, onlySchemasDump, "some-id") + + onlyPublicSchemaDump, err := pgdump.GetDump(db, pgdump.WithSchemaOnly(), pgdump.WithExcludeSchema("test")) + require.NoError(t, err) + require.Contains(t, onlyPublicSchemaDump, "public.foobar") + require.NotContains(t, onlyPublicSchemaDump, "test.bar") + require.NotContains(t, onlyPublicSchemaDump, "some-id") +} diff --git a/internal/pgengine/db.go b/internal/pgengine/db.go new file mode 100644 index 0000000..2b24105 --- /dev/null +++ b/internal/pgengine/db.go @@ -0,0 +1,61 @@ +package pgengine + +import ( + "database/sql" + "fmt" + + _ "github.com/jackc/pgx/v4/stdlib" +) + +type DB struct { + connOpts ConnectionOptions + + dropped bool +} + +func (d *DB) GetName() string { + return d.connOpts[ConnectionOptionDatabase] +} + +func (d *DB) GetConnOpts() ConnectionOptions { + return d.connOpts +} + +func (d *DB) GetDSN() string { + return d.GetConnOpts().ToDSN() +} + +// DropDB drops the database +func (d *DB) DropDB() error { + if d.dropped { + return nil + } + + // Use the pgDsn as we are dropping the test database + db, err := sql.Open("pgx", d.GetConnOpts().With(ConnectionOptionDatabase, "postgres").ToDSN()) + if err != nil { + return err + } + defer db.Close() + + // Disallow further connections to the test database, except for superusers + _, err = db.Exec(fmt.Sprintf("ALTER DATABASE \"%s\" CONNECTION LIMIT 0", d.GetName())) + if err != nil { + return err + } + + // Drop existing connections, so that we can drop the table + _, err = db.Exec("SELECT PG_TERMINATE_BACKEND(pid) FROM pg_stat_activity WHERE datname = $1", d.GetName()) + if err != nil { + return err + } + + // Finally, drop the table + _, err = db.Exec(fmt.Sprintf("DROP DATABASE \"%s\"", d.GetName())) + if err != nil { + return err + } + + d.dropped = true + return nil +} diff --git a/internal/pgengine/engine.go b/internal/pgengine/engine.go new file mode 100644 index 0000000..b5bcaa1 --- /dev/null +++ b/internal/pgengine/engine.go @@ -0,0 +1,235 @@ +package pgengine + +import ( + "database/sql" + "errors" + "fmt" + "os" + "os/exec" + "os/user" + "path" + "strconv" + "strings" + "time" + + "github.com/google/uuid" + _ "github.com/jackc/pgx/v4/stdlib" +) + +const ( + port = 5432 + + maxConnAttemptsAtStartup = 10 + waitBetweenStartupConnAttempt = time.Second +) + +type ConnectionOption string + +const ( + ConnectionOptionDatabase ConnectionOption = "dbname" +) + +type ConnectionOptions map[ConnectionOption]string + +func (c ConnectionOptions) With(option ConnectionOption, value string) ConnectionOptions { + clone := make(ConnectionOptions) + for k, v := range c { + clone[k] = v + } + clone[option] = value + return clone +} + +func (c ConnectionOptions) ToDSN() string { + var pairs []string + for k, v := range c { + pairs = append(pairs, fmt.Sprintf("%s%s%s", k, "=", v)) + } + + return strings.Join(pairs, " ") +} + +type Engine struct { + user *user.User + + // for cleanup purposes + process *os.Process + dbPath string + sockPath string +} + +// StartEngine starts a postgres instance. This is useful for testing, where Postgres databases need to be spun up. +// "postgres" must be on the system's PATH, and the binary must be located in a directory containing "initdb" +func StartEngine() (*Engine, error) { + postgresPath, err := exec.LookPath("postgres") + if err != nil { + return nil, errors.New("postgres executable not found in path") + } + return StartEngineUsingPgDir(path.Dir(postgresPath)) +} + +func StartEngineUsingPgDir(pgDir string) (*Engine, error) { + currentUser, err := user.Current() + if err != nil { + return nil, err + } + + dbPath, err := os.MkdirTemp("", "postgresql-") + if err != nil { + return nil, err + } + + sockPath, err := os.MkdirTemp("", "pgsock-") + if err != nil { + return nil, err + } + + if err := initDB(currentUser, path.Join(pgDir, "initdb"), dbPath); err != nil { + return nil, err + } + + process, err := startServer(path.Join(pgDir, "postgres"), dbPath, sockPath) + if err != nil { + // Cleanup temporary directories that were created + os.RemoveAll(dbPath) + os.RemoveAll(sockPath) + return nil, err + } + + pgEngine := &Engine{ + dbPath: dbPath, + sockPath: sockPath, + user: currentUser, + process: process, + } + if err := pgEngine.waitTillServingTraffic(maxConnAttemptsAtStartup, waitBetweenStartupConnAttempt); err != nil { + pgEngine.Close() + return nil, fmt.Errorf("waiting till server can serve traffic: %w", err) + } + return pgEngine, nil +} + +func initDB(currentUser *user.User, initDbPath, dbPath string) error { + cmd := exec.Command(initDbPath, []string{ + "-U", currentUser.Username, + "-D", dbPath, + "-A", "trust", + }...) + + output, err := cmd.CombinedOutput() + if err != nil { + outputStr := string(output) + + var tip string + line := strings.Repeat("=", 95) + if strings.Contains(outputStr, "request for a shared memory segment exceeded your kernel's SHMALL parameter") { + tip = line + "\n Run 'sudo sysctl -w kern.sysv.shmall=16777216' to solve this issue \n" + line + "\n" + } else if strings.Contains(outputStr, "could not create shared memory segment: No space left on device") { + tip = line + "\n Use the ipcs and ipcrm commands to clear the shared memory \n" + line + "\n" + } + + return fmt.Errorf("error running initdb: %w\n%s\n%s", err, outputStr, tip) + } + return nil +} + +func startServer(pgBinaryPath, dbPath, sockPath string) (*os.Process, error) { + cmd := exec.Command(pgBinaryPath, []string{ + "-D", dbPath, + "-k", sockPath, + "-p", strconv.Itoa(port), + "-h", ""}...) + + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + if err := cmd.Start(); err != nil { + return nil, fmt.Errorf("starting postgres server instance: %w", err) + } + + return cmd.Process, nil +} + +func (e *Engine) waitTillServingTraffic(maxAttempts int, timeBetweenAttempts time.Duration) error { + var mostRecentErr error + for i := 0; i < maxAttempts; i++ { + mostRecentErr = e.testIfInstanceServingTraffic() + if mostRecentErr == nil { + return nil + } + time.Sleep(timeBetweenAttempts) + } + return fmt.Errorf("unable to establish connection to postgres instance. most recent error: %w", mostRecentErr) +} + +func (e *Engine) testIfInstanceServingTraffic() error { + dsn := e.GetPostgresDatabaseConnOpts().ToDSN() + db, err := sql.Open("pgx", dsn) + if err != nil { + return err + } + + if err := db.Ping(); err != nil { + db.Close() + return err + } + // Don't `defer` the call to DropDB(), since we want to return the error if any + return db.Close() +} + +func (e *Engine) GetPostgresDatabaseConnOpts() ConnectionOptions { + result := make(map[ConnectionOption]string) + result[ConnectionOptionDatabase] = "postgres" + result["host"] = e.sockPath + result["port"] = strconv.Itoa(port) + result["sslmode"] = "disable" + + return result +} + +func (e *Engine) GetPostgresDatabaseDSN() string { + return e.GetPostgresDatabaseConnOpts().ToDSN() +} + +func (e *Engine) Close() error { + // Make best effort attempt to clean up everything + e.process.Signal(os.Interrupt) + e.process.Wait() + os.RemoveAll(e.dbPath) + os.RemoveAll(e.dbPath) + + return nil +} + +func (e *Engine) CreateDatabase() (*DB, error) { + uuid, err := uuid.NewRandom() + if err != nil { + return nil, fmt.Errorf("generating uuid: %w", err) + } + testDBName := fmt.Sprintf("pgtestdb_%v", uuid.String()) + + testDb, err := e.CreateDatabaseWithName(testDBName) + if err != nil { + return nil, err + } + + return testDb, err +} + +func (e *Engine) CreateDatabaseWithName(name string) (*DB, error) { + dsn := e.GetPostgresDatabaseConnOpts().With(ConnectionOptionDatabase, "postgres").ToDSN() + db, err := sql.Open("pgx", dsn) + if err != nil { + return nil, err + } + defer db.Close() + + _, err = db.Exec(fmt.Sprintf("CREATE DATABASE \"%s\"", name)) + if err != nil { + return nil, err + } + + return &DB{ + connOpts: e.GetPostgresDatabaseConnOpts().With(ConnectionOptionDatabase, name), + }, nil +} diff --git a/internal/pgengine/engine_test.go b/internal/pgengine/engine_test.go new file mode 100644 index 0000000..a5da12c --- /dev/null +++ b/internal/pgengine/engine_test.go @@ -0,0 +1,105 @@ +package pgengine_test + +import ( + "context" + "database/sql" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stripe/pg-schema-diff/internal/pgengine" + + "github.com/stretchr/testify/require" +) + +func TestPgEngine(t *testing.T) { + engine, err := pgengine.StartEngine() + require.NoError(t, err) + defer func() { + // Drops should be idempotent + require.NoError(t, engine.Close()) + require.NoError(t, engine.Close()) + }() + + unnamedDb, err := engine.CreateDatabase() + require.NoError(t, err) + assertDatabaseIsValid(t, unnamedDb, "") + + namedDb, err := engine.CreateDatabaseWithName("some-name") + require.NoError(t, err) + assertDatabaseIsValid(t, namedDb, "some-name") + + // Assert no extra databases were created + assert.ElementsMatch(t, getAllDatabaseNames(t, engine), []string{unnamedDb.GetName(), "some-name", "postgres", "template0", "template1"}) + + // Hold open a connection before we try to drop the database. The drop should still pass, despite the open + // connection + connPool, err := sql.Open("pgx", unnamedDb.GetDSN()) + require.NoError(t, err) + defer connPool.Close() + conn, err := connPool.Conn(context.Background()) + require.NoError(t, err) + defer conn.Close() + + // Drops should be idempotent + require.NoError(t, unnamedDb.DropDB()) + require.NoError(t, unnamedDb.DropDB()) + require.NoError(t, namedDb.DropDB()) + require.NoError(t, namedDb.DropDB()) + + // Assert only the expected databases were dropped + assert.ElementsMatch(t, getAllDatabaseNames(t, engine), []string{"postgres", "template0", "template1"}) +} + +func assertDatabaseIsValid(t *testing.T, db *pgengine.DB, expectedName string) { + connPool, err := sql.Open("pgx", db.GetDSN()) + require.NoError(t, err) + defer func() { + require.NoError(t, connPool.Close()) + }() + + var fetchedName string + require.NoError(t, connPool.QueryRowContext(context.Background(), "SELECT current_database();"). + Scan(&fetchedName)) + + assert.Equal(t, db.GetName(), fetchedName) + if len(expectedName) > 0 { + assert.Equal(t, fetchedName, expectedName) + } + + // Validate writing and reading works + _, err = connPool.ExecContext(context.Background(), `CREATE TABLE foobar(id serial NOT NULL)`) + require.NoError(t, err) + + _, err = connPool.ExecContext(context.Background(), `INSERT INTO foobar DEFAULT VALUES`) + require.NoError(t, err) + + var id string + require.NoError(t, connPool.QueryRowContext(context.Background(), "SELECT * FROM foobar LIMIT 1;"). + Scan(&id)) + assert.NotEmptyf(t, t, id) +} + +func getAllDatabaseNames(t *testing.T, engine *pgengine.Engine) []string { + conn, err := sql.Open("pgx", engine.GetPostgresDatabaseDSN()) + require.NoError(t, err) + defer func() { + require.NoError(t, conn.Close()) + }() + + rows, err := conn.QueryContext(context.Background(), "SELECT datname FROM pg_database") + require.NoError(t, err) + defer rows.Close() + + // Iterate over the rows and put the database names in an array + var dbNames []string + for rows.Next() { + var dbName string + if err := rows.Scan(&dbName); err != nil { + require.NoError(t, err) + } + dbNames = append(dbNames, dbName) + } + require.NoError(t, rows.Err()) + + return dbNames +} diff --git a/internal/pgidentifier/identifier.go b/internal/pgidentifier/identifier.go new file mode 100644 index 0000000..d50d47c --- /dev/null +++ b/internal/pgidentifier/identifier.go @@ -0,0 +1,12 @@ +package pgidentifier + +import "regexp" + +var ( + // SimpleIdentifierRegex matches identifiers in Postgres that require no quotes + SimpleIdentifierRegex = regexp.MustCompile("^[a-z_][a-z0-9_$]*$") +) + +func IsSimpleIdentifier(val string) bool { + return SimpleIdentifierRegex.MatchString(val) +} diff --git a/internal/pgidentifier/identifier_test.go b/internal/pgidentifier/identifier_test.go new file mode 100644 index 0000000..4efc384 --- /dev/null +++ b/internal/pgidentifier/identifier_test.go @@ -0,0 +1,56 @@ +package pgidentifier + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_IsSimpleIdentifier(t *testing.T) { + for _, tc := range []struct { + name string + input string + expected bool + }{ + { + name: "starts with letter", + input: "foo", + expected: true, + }, + { + name: "starts with underscore", + input: "_foo", + expected: true, + }, + { + name: "start with number", + input: "1foo", + expected: false, + }, + { + name: "contains all possible characters", + input: "some_1119$_", + expected: true, + }, + { + name: "empty", + input: "", + expected: false, + }, + { + name: "contains upper case letter", + input: "fooBar", + expected: false, + }, + { + name: "contains spaces", + input: "foo bar", + expected: false, + }, + } { + t.Run(tc.name, func(t *testing.T) { + actual := IsSimpleIdentifier(tc.input) + assert.Equal(t, tc.expected, actual) + }) + } +} diff --git a/internal/queries/README.md b/internal/queries/README.md new file mode 100644 index 0000000..74f4f1f --- /dev/null +++ b/internal/queries/README.md @@ -0,0 +1,15 @@ +# Expected workflow + +1. Add a SQL query/statement in `queries.sql`, following the other examples. +2. Run `make sqlc` using the same version of sqlc as in `build/Dockerfile.codegen` + +`sqlc` decides what the return-type of the generated method should be based on the `:exec` suffix (documentation [here](https://docs.sqlc.dev/en/latest/reference/query-annotations.html)): + - `:exec` will only tell you whether the query succeeded: `error` + - `:execrows` will tell you how many rows were affected: `(int64, error)` + - `:one` will give you back a single struct: `(Author, error)` + - `:many` will give you back a slice of structs: `([]Author, error)` + + +It is configured by the `sqlc.yaml` file. You can read docs about the various +options it supports +[here](https://docs.sqlc.dev/en/latest/reference/config.html). diff --git a/internal/queries/dml.sql.go b/internal/queries/dml.sql.go new file mode 100644 index 0000000..9df9a79 --- /dev/null +++ b/internal/queries/dml.sql.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.13.0 + +package queries + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/queries/models.sql.go b/internal/queries/models.sql.go new file mode 100644 index 0000000..2a05a43 --- /dev/null +++ b/internal/queries/models.sql.go @@ -0,0 +1,7 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.13.0 + +package queries + +import () diff --git a/internal/queries/queries.sql b/internal/queries/queries.sql new file mode 100644 index 0000000..5e90293 --- /dev/null +++ b/internal/queries/queries.sql @@ -0,0 +1,121 @@ +-- name: GetTables :many +SELECT c.oid AS oid, + c.relname AS table_name, + COALESCE(parent_c.relname, '') AS parent_table_name, + COALESCE(parent_namespace.nspname, '') AS parent_table_schema_name, + (CASE + WHEN c.relkind = 'p' THEN pg_catalog.pg_get_partkeydef(c.oid) + ELSE '' + END)::text + AS partition_key_def, + (CASE + WHEN c.relispartition THEN pg_catalog.pg_get_expr(c.relpartbound, c.oid) + ELSE '' + END)::text AS partition_for_values +FROM pg_catalog.pg_class c + LEFT JOIN pg_catalog.pg_inherits inherits ON inherits.inhrelid = c.oid + LEFT JOIN pg_catalog.pg_class parent_c ON inherits.inhparent = parent_c.oid + LEFT JOIN pg_catalog.pg_namespace as parent_namespace ON parent_c.relnamespace = parent_namespace.oid +WHERE c.relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = 'public') + AND (c.relkind = 'r' OR c.relkind = 'p'); + +-- name: GetColumnsForTable :many +SELECT a.attname AS column_name, + pg_catalog.format_type(a.atttypid, a.atttypmod) AS column_type, + COALESCE(coll.collname, '') AS collation_name, + COALESCE(collation_namespace.nspname, '') AS collation_schema_name, + COALESCE(pg_catalog.pg_get_expr(d.adbin, d.adrelid), '')::TEXT AS default_value, + a.attnotnull AS is_not_null, + a.attlen AS column_size +FROM pg_catalog.pg_attribute a + LEFT JOIN pg_catalog.pg_attrdef d ON (d.adrelid = a.attrelid AND d.adnum = a.attnum) + LEFT JOIN pg_catalog.pg_collation coll ON coll.oid = a.attcollation + LEFT JOIN pg_catalog.pg_namespace collation_namespace ON collation_namespace.oid = coll.collnamespace +WHERE a.attrelid = $1 + AND a.attnum > 0 + AND NOT a.attisdropped +ORDER BY a.attnum; + +-- name: GetIndexes :many +SELECT c.oid AS oid, + c.relname as index_name, + (SELECT c1.relname AS table_name FROM pg_catalog.pg_class c1 WHERE c1.oid = i.indrelid), + pg_catalog.pg_get_indexdef(c.oid)::TEXT as def_stmt, + COALESCE(con.conname, '') as constraint_name, + i.indisvalid as index_is_valid, + i.indisprimary as index_is_pk, + i.indisunique AS index_is_unique, + COALESCE(parent_c.relname, '') as parent_index_name, + COALESCE(parent_namespace.nspname, '') as parent_index_schema_name +FROM pg_catalog.pg_class c + INNER JOIN pg_catalog.pg_index i ON (i.indexrelid = c.oid) + LEFT JOIN pg_catalog.pg_constraint con ON (con.conindid = c.oid) + LEFT JOIN pg_catalog.pg_inherits inherits ON (c.oid = inherits.inhrelid) + LEFT JOIN pg_catalog.pg_class parent_c ON (inherits.inhparent = parent_c.oid) + LEFT JOIN pg_catalog.pg_namespace as parent_namespace ON parent_c.relnamespace = parent_namespace.oid +WHERE c.relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = 'public') + AND (c.relkind = 'i' OR c.relkind = 'I'); + +-- name: GetColumnsForIndex :many +SELECT a.attname AS column_name +FROM pg_catalog.pg_attribute a +WHERE a.attrelid = $1 + AND a.attnum > 0 +ORDER BY a.attnum; + +-- name: GetCheckConstraints :many +SELECT pg_constraint.oid, + conname as name, + pg_class.relname as table_name, + pg_catalog.pg_get_expr(conbin, conrelid) as expression, + convalidated as is_valid, + connoinherit as is_not_inheritable +FROM pg_catalog.pg_constraint + JOIN pg_catalog.pg_class ON pg_constraint.conrelid = pg_class.oid +WHERE pg_class.relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = 'public') + AND contype = 'c' + AND pg_constraint.conislocal; + +-- name: GetFunctions :many +SELECT proc.oid, + proname as func_name, + pg_catalog.pg_get_function_identity_arguments(proc.oid) as func_identity_arguments, + proc_namespace.nspname as func_schema_name, + pg_catalog.pg_get_functiondef(proc.oid) as func_def, + proc_lang.lanname as func_lang +FROM pg_catalog.pg_proc proc + JOIN pg_catalog.pg_namespace proc_namespace ON proc.pronamespace = proc_namespace.oid + JOIN pg_catalog.pg_language proc_lang ON proc_lang.oid = proc.prolang +WHERE proc_namespace.nspname = 'public' + AND proc.prokind = 'f' + -- Exclude functions belonging to extensions + AND NOT EXISTS(SELECT depend.objid FROM pg_catalog.pg_depend depend WHERE deptype = 'e' AND depend.objid = proc.oid); + +-- name: GetDependsOnFunctions :many +SELECT proc.proname as func_name, + pg_catalog.pg_get_function_identity_arguments(proc.oid) as func_identity_arguments, + proc_namespace.nspname as func_schema_name +FROM pg_catalog.pg_depend depend + JOIN pg_catalog.pg_proc proc ON depend.refobjid = proc.oid + JOIN pg_catalog.pg_namespace proc_namespace ON proc.pronamespace = proc_namespace.oid +WHERE depend.objid = $1 + AND depend.deptype = 'n'; + +-- name: GetTriggers :many +SELECT trig.tgname as trigger_name, + owning_c.relname as owning_table_name, + owning_c_namespace.nspname as owning_table_schema_name, + proc.proname as func_name, + pg_catalog.pg_get_function_identity_arguments(proc.oid) as func_identity_arguments, + proc_namespace.nspname as func_schema_name, + pg_catalog.pg_get_triggerdef(trig.oid) as trigger_def +FROM pg_catalog.pg_trigger trig + JOIN pg_catalog.pg_class owning_c ON trig.tgrelid = owning_c.oid + JOIN pg_catalog.pg_namespace owning_c_namespace ON owning_c.relnamespace = owning_c_namespace.oid + JOIN pg_catalog.pg_proc proc ON trig.tgfoid = proc.oid + JOIN pg_catalog.pg_namespace proc_namespace ON proc.pronamespace = proc_namespace.oid +WHERE proc_namespace.nspname = 'public' + AND owning_c_namespace.nspname = 'public' + AND trig.tgparentid = 0 + AND NOT trig.tgisinternal; + diff --git a/internal/queries/queries.sql.go b/internal/queries/queries.sql.go new file mode 100644 index 0000000..57e62dd --- /dev/null +++ b/internal/queries/queries.sql.go @@ -0,0 +1,437 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.13.0 +// source: queries.sql + +package queries + +import ( + "context" +) + +const getCheckConstraints = `-- name: GetCheckConstraints :many +SELECT pg_constraint.oid, + conname as name, + pg_class.relname as table_name, + pg_catalog.pg_get_expr(conbin, conrelid) as expression, + convalidated as is_valid, + connoinherit as is_not_inheritable +FROM pg_catalog.pg_constraint + JOIN pg_catalog.pg_class ON pg_constraint.conrelid = pg_class.oid +WHERE pg_class.relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = 'public') + AND contype = 'c' + AND pg_constraint.conislocal +` + +type GetCheckConstraintsRow struct { + Oid interface{} + Name string + TableName string + Expression string + IsValid bool + IsNotInheritable bool +} + +func (q *Queries) GetCheckConstraints(ctx context.Context) ([]GetCheckConstraintsRow, error) { + rows, err := q.db.QueryContext(ctx, getCheckConstraints) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetCheckConstraintsRow + for rows.Next() { + var i GetCheckConstraintsRow + if err := rows.Scan( + &i.Oid, + &i.Name, + &i.TableName, + &i.Expression, + &i.IsValid, + &i.IsNotInheritable, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getColumnsForIndex = `-- name: GetColumnsForIndex :many +SELECT a.attname AS column_name +FROM pg_catalog.pg_attribute a +WHERE a.attrelid = $1 + AND a.attnum > 0 +ORDER BY a.attnum +` + +func (q *Queries) GetColumnsForIndex(ctx context.Context, attrelid interface{}) ([]string, error) { + rows, err := q.db.QueryContext(ctx, getColumnsForIndex, attrelid) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var column_name string + if err := rows.Scan(&column_name); err != nil { + return nil, err + } + items = append(items, column_name) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getColumnsForTable = `-- name: GetColumnsForTable :many +SELECT a.attname AS column_name, + pg_catalog.format_type(a.atttypid, a.atttypmod) AS column_type, + COALESCE(coll.collname, '') AS collation_name, + COALESCE(collation_namespace.nspname, '') AS collation_schema_name, + COALESCE(pg_catalog.pg_get_expr(d.adbin, d.adrelid), '')::TEXT AS default_value, + a.attnotnull AS is_not_null, + a.attlen AS column_size +FROM pg_catalog.pg_attribute a + LEFT JOIN pg_catalog.pg_attrdef d ON (d.adrelid = a.attrelid AND d.adnum = a.attnum) + LEFT JOIN pg_catalog.pg_collation coll ON coll.oid = a.attcollation + LEFT JOIN pg_catalog.pg_namespace collation_namespace ON collation_namespace.oid = coll.collnamespace +WHERE a.attrelid = $1 + AND a.attnum > 0 + AND NOT a.attisdropped +ORDER BY a.attnum +` + +type GetColumnsForTableRow struct { + ColumnName string + ColumnType string + CollationName string + CollationSchemaName string + DefaultValue string + IsNotNull bool + ColumnSize int16 +} + +func (q *Queries) GetColumnsForTable(ctx context.Context, attrelid interface{}) ([]GetColumnsForTableRow, error) { + rows, err := q.db.QueryContext(ctx, getColumnsForTable, attrelid) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetColumnsForTableRow + for rows.Next() { + var i GetColumnsForTableRow + if err := rows.Scan( + &i.ColumnName, + &i.ColumnType, + &i.CollationName, + &i.CollationSchemaName, + &i.DefaultValue, + &i.IsNotNull, + &i.ColumnSize, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getDependsOnFunctions = `-- name: GetDependsOnFunctions :many +SELECT proc.proname as func_name, + pg_catalog.pg_get_function_identity_arguments(proc.oid) as func_identity_arguments, + proc_namespace.nspname as func_schema_name +FROM pg_catalog.pg_depend depend + JOIN pg_catalog.pg_proc proc ON depend.refobjid = proc.oid + JOIN pg_catalog.pg_namespace proc_namespace ON proc.pronamespace = proc_namespace.oid +WHERE depend.objid = $1 + AND depend.deptype = 'n' +` + +type GetDependsOnFunctionsRow struct { + FuncName string + FuncIdentityArguments string + FuncSchemaName string +} + +func (q *Queries) GetDependsOnFunctions(ctx context.Context, objid interface{}) ([]GetDependsOnFunctionsRow, error) { + rows, err := q.db.QueryContext(ctx, getDependsOnFunctions, objid) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetDependsOnFunctionsRow + for rows.Next() { + var i GetDependsOnFunctionsRow + if err := rows.Scan(&i.FuncName, &i.FuncIdentityArguments, &i.FuncSchemaName); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getFunctions = `-- name: GetFunctions :many +SELECT proc.oid, + proname as func_name, + pg_catalog.pg_get_function_identity_arguments(proc.oid) as func_identity_arguments, + proc_namespace.nspname as func_schema_name, + pg_catalog.pg_get_functiondef(proc.oid) as func_def, + proc_lang.lanname as func_lang +FROM pg_catalog.pg_proc proc + JOIN pg_catalog.pg_namespace proc_namespace ON proc.pronamespace = proc_namespace.oid + JOIN pg_catalog.pg_language proc_lang ON proc_lang.oid = proc.prolang +WHERE proc_namespace.nspname = 'public' + AND proc.prokind = 'f' + -- Exclude functions belonging to extensions + AND NOT EXISTS(SELECT depend.objid FROM pg_catalog.pg_depend depend WHERE deptype = 'e' AND depend.objid = proc.oid) +` + +type GetFunctionsRow struct { + Oid interface{} + FuncName string + FuncIdentityArguments string + FuncSchemaName string + FuncDef string + FuncLang string +} + +func (q *Queries) GetFunctions(ctx context.Context) ([]GetFunctionsRow, error) { + rows, err := q.db.QueryContext(ctx, getFunctions) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetFunctionsRow + for rows.Next() { + var i GetFunctionsRow + if err := rows.Scan( + &i.Oid, + &i.FuncName, + &i.FuncIdentityArguments, + &i.FuncSchemaName, + &i.FuncDef, + &i.FuncLang, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getIndexes = `-- name: GetIndexes :many +SELECT c.oid AS oid, + c.relname as index_name, + (SELECT c1.relname AS table_name FROM pg_catalog.pg_class c1 WHERE c1.oid = i.indrelid), + pg_catalog.pg_get_indexdef(c.oid)::TEXT as def_stmt, + COALESCE(con.conname, '') as constraint_name, + i.indisvalid as index_is_valid, + i.indisprimary as index_is_pk, + i.indisunique AS index_is_unique, + COALESCE(parent_c.relname, '') as parent_index_name, + COALESCE(parent_namespace.nspname, '') as parent_index_schema_name +FROM pg_catalog.pg_class c + INNER JOIN pg_catalog.pg_index i ON (i.indexrelid = c.oid) + LEFT JOIN pg_catalog.pg_constraint con ON (con.conindid = c.oid) + LEFT JOIN pg_catalog.pg_inherits inherits ON (c.oid = inherits.inhrelid) + LEFT JOIN pg_catalog.pg_class parent_c ON (inherits.inhparent = parent_c.oid) + LEFT JOIN pg_catalog.pg_namespace as parent_namespace ON parent_c.relnamespace = parent_namespace.oid +WHERE c.relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = 'public') + AND (c.relkind = 'i' OR c.relkind = 'I') +` + +type GetIndexesRow struct { + Oid interface{} + IndexName string + TableName string + DefStmt string + ConstraintName string + IndexIsValid bool + IndexIsPk bool + IndexIsUnique bool + ParentIndexName string + ParentIndexSchemaName string +} + +func (q *Queries) GetIndexes(ctx context.Context) ([]GetIndexesRow, error) { + rows, err := q.db.QueryContext(ctx, getIndexes) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetIndexesRow + for rows.Next() { + var i GetIndexesRow + if err := rows.Scan( + &i.Oid, + &i.IndexName, + &i.TableName, + &i.DefStmt, + &i.ConstraintName, + &i.IndexIsValid, + &i.IndexIsPk, + &i.IndexIsUnique, + &i.ParentIndexName, + &i.ParentIndexSchemaName, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getTables = `-- name: GetTables :many +SELECT c.oid AS oid, + c.relname AS table_name, + COALESCE(parent_c.relname, '') AS parent_table_name, + COALESCE(parent_namespace.nspname, '') AS parent_table_schema_name, + (CASE + WHEN c.relkind = 'p' THEN pg_catalog.pg_get_partkeydef(c.oid) + ELSE '' + END)::text + AS partition_key_def, + (CASE + WHEN c.relispartition THEN pg_catalog.pg_get_expr(c.relpartbound, c.oid) + ELSE '' + END)::text AS partition_for_values +FROM pg_catalog.pg_class c + LEFT JOIN pg_catalog.pg_inherits inherits ON inherits.inhrelid = c.oid + LEFT JOIN pg_catalog.pg_class parent_c ON inherits.inhparent = parent_c.oid + LEFT JOIN pg_catalog.pg_namespace as parent_namespace ON parent_c.relnamespace = parent_namespace.oid +WHERE c.relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = 'public') + AND (c.relkind = 'r' OR c.relkind = 'p') +` + +type GetTablesRow struct { + Oid interface{} + TableName string + ParentTableName string + ParentTableSchemaName string + PartitionKeyDef string + PartitionForValues string +} + +func (q *Queries) GetTables(ctx context.Context) ([]GetTablesRow, error) { + rows, err := q.db.QueryContext(ctx, getTables) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetTablesRow + for rows.Next() { + var i GetTablesRow + if err := rows.Scan( + &i.Oid, + &i.TableName, + &i.ParentTableName, + &i.ParentTableSchemaName, + &i.PartitionKeyDef, + &i.PartitionForValues, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getTriggers = `-- name: GetTriggers :many +SELECT trig.tgname as trigger_name, + owning_c.relname as owning_table_name, + owning_c_namespace.nspname as owning_table_schema_name, + proc.proname as func_name, + pg_catalog.pg_get_function_identity_arguments(proc.oid) as func_identity_arguments, + proc_namespace.nspname as func_schema_name, + pg_catalog.pg_get_triggerdef(trig.oid) as trigger_def +FROM pg_catalog.pg_trigger trig + JOIN pg_catalog.pg_class owning_c ON trig.tgrelid = owning_c.oid + JOIN pg_catalog.pg_namespace owning_c_namespace ON owning_c.relnamespace = owning_c_namespace.oid + JOIN pg_catalog.pg_proc proc ON trig.tgfoid = proc.oid + JOIN pg_catalog.pg_namespace proc_namespace ON proc.pronamespace = proc_namespace.oid +WHERE proc_namespace.nspname = 'public' + AND owning_c_namespace.nspname = 'public' + AND trig.tgparentid = 0 + AND NOT trig.tgisinternal +` + +type GetTriggersRow struct { + TriggerName string + OwningTableName string + OwningTableSchemaName string + FuncName string + FuncIdentityArguments string + FuncSchemaName string + TriggerDef string +} + +func (q *Queries) GetTriggers(ctx context.Context) ([]GetTriggersRow, error) { + rows, err := q.db.QueryContext(ctx, getTriggers) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetTriggersRow + for rows.Next() { + var i GetTriggersRow + if err := rows.Scan( + &i.TriggerName, + &i.OwningTableName, + &i.OwningTableSchemaName, + &i.FuncName, + &i.FuncIdentityArguments, + &i.FuncSchemaName, + &i.TriggerDef, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/queries/sqlc.yaml b/internal/queries/sqlc.yaml new file mode 100644 index 0000000..d1c568a --- /dev/null +++ b/internal/queries/sqlc.yaml @@ -0,0 +1,10 @@ +version: 1 +packages: + - path: "." + name: "queries" + engine: "postgresql" + queries: "queries.sql" + schema: "system_tables.sql" + output_db_file_name: "dml.sql.go" + output_models_file_name: "models.sql.go" + output_querier_file_name: "querier.sql.go" diff --git a/internal/queries/system_tables.sql b/internal/queries/system_tables.sql new file mode 100644 index 0000000..b3cf533 --- /dev/null +++ b/internal/queries/system_tables.sql @@ -0,0 +1,385 @@ +-- postgres=# \d pg_catalog.pg_namespace; +-- Table "pg_catalog.pg_namespace" +-- Column | Type | Collation | Nullable | Default +-- ----------+-----------+-----------+----------+--------- +-- oid | oid | | not null | +-- nspname | name | | not null | +-- nspowner | oid | | not null | +-- nspacl | aclitem[] | | | + +CREATE TABLE pg_catalog.pg_namespace +( + oid OID NOT NULL, + nspname TEXT NOT NULL +); + + +-- postgres=# \d pg_catalog.pg_class; +-- Table "pg_catalog.pg_class" +-- Column | Type | Collation | Nullable | Default +-- ---------------------+--------------+-----------+----------+--------- +-- oid | oid | | not null | +-- relname | name | | not null | +-- relnamespace | oid | | not null | +-- reltype | oid | | not null | +-- reloftype | oid | | not null | +-- relowner | oid | | not null | +-- relam | oid | | not null | +-- relfilenode | oid | | not null | +-- reltablespace | oid | | not null | +-- relpages | integer | | not null | +-- reltuples | real | | not null | +-- relallvisible | integer | | not null | +-- reltoastrelid | oid | | not null | +-- relhasindex | boolean | | not null | +-- relisshared | boolean | | not null | +-- relpersistence | "char" | | not null | +-- relkind | "char" | | not null | +-- relnatts | smallint | | not null | +-- relchecks | smallint | | not null | +-- relhasrules | boolean | | not null | +-- relhastriggers | boolean | | not null | +-- relhassubclass | boolean | | not null | +-- relrowsecurity | boolean | | not null | +-- relforcerowsecurity | boolean | | not null | +-- relispopulated | boolean | | not null | +-- relreplident | "char" | | not null | +-- relispartition | boolean | | not null | +-- relrewrite | oid | | not null | +-- relfrozenxid | xid | | not null | +-- relminmxid | xid | | not null | +-- relacl | aclitem[] | | | +-- reloptions | text[] | C | | +-- relpartbound | pg_node_tree | C | | +-- Indexes: +-- "pg_class_oid_index" PRIMARY KEY, btree (oid) +-- "pg_class_relname_nsp_index" UNIQUE CONSTRAINT, btree (relname, relnamespace) +-- "pg_class_tblspc_relfilenode_index" btree (reltablespace, relfilenode) + +CREATE TABLE pg_catalog.pg_class +( + oid OID NOT NULL, + relname TEXT NOT NULL, + relnamespace INT NOT NULL, + reltype INT NOT NULL +); + +-- postgres=# \d pg_catalog.pg_inherits +-- Column | Type | Collation | Nullable | Default +-- ------------------+---------+-----------+----------+--------- +-- inhrelid | oid | | not null | +-- inhparent | oid | | not null | +-- inhseqno | integer | | not null | +-- inhdetachpending | boolean | | not null | +-- Indexes: +-- "pg_inherits_relid_seqno_index" PRIMARY KEY, btree (inhrelid, inhseqno) +-- "pg_inherits_parent_index" btree (inhparent) +-- +CREATE TABLE pg_catalog.pg_inherits +( + inhrelid OID NOT NULL, + inhparent OID NOT NULL +); + +-- postgres=# \d pg_catalog.pg_attribute; +-- Table "pg_catalog.pg_attribute" +-- Column | Type | Collation | Nullable | Default +-- ----------------+-----------+-----------+----------+--------- +-- attrelid | oid | | not null | +-- attname | name | | not null | +-- atttypid | oid | | not null | +-- attstattarget | integer | | not null | +-- attlen | smallint | | not null | +-- attnum | smallint | | not null | +-- attndims | integer | | not null | +-- attcacheoff | integer | | not null | +-- atttypmod | integer | | not null | +-- attbyval | boolean | | not null | +-- attalign | "char" | | not null | +-- attstorage | "char" | | not null | +-- attcompression | "char" | | not null | +-- attnotnull | boolean | | not null | +-- atthasdef | boolean | | not null | +-- atthasmissing | boolean | | not null | +-- attidentity | "char" | | not null | +-- attgenerated | "char" | | not null | +-- attisdropped | boolean | | not null | +-- attislocal | boolean | | not null | +-- attinhcount | integer | | not null | +-- attcollation | oid | | not null | +-- attacl | aclitem[] | | | +-- attoptions | text[] | C | | +-- attfdwoptions | text[] | C | | +-- attmissingval | anyarray | | | +-- Indexes: +-- "pg_attribute_relid_attnum_index" PRIMARY KEY, btree (attrelid, attnum) +-- "pg_attribute_relid_attnam_index" UNIQUE CONSTRAINT, btree (attrelid, attname) + +CREATE TABLE pg_catalog.pg_attribute +( + attrelid OID NOT NULL, + attname TEXT NOT NULL, + attlen SMALLINT NOT NULL, + attnum SMALLINT NOT NULL, + atttypid INTEGER NOT NULL, + atttypmod INTEGER NOT NULL, + attnotnull BOOLEAN NOT NULL, + attcollation OID NOT NULL +); + +-- postgres=# \d pg_catalog.pg_collation; +-- Table "pg_catalog.pg_collation" +-- Column | Type | Collation | Nullable | Default +-- ---------------------+---------+-----------+----------+--------- +-- oid | oid | | not null | +-- collname | name | | not null | +-- collnamespace | oid | | not null | +-- collowner | oid | | not null | +-- collprovider | "char" | | not null | +-- collisdeterministic | boolean | | not null | +-- collencoding | integer | | not null | +-- collcollate | name | | not null | +-- collctype | name | | not null | +-- collversion | text | C | | +-- Indexes: +-- "pg_collation_oid_index" PRIMARY KEY, btree (oid) +-- "pg_collation_name_enc_nsp_index" UNIQUE CONSTRAINT, btree (collname, collencoding, collnamespace) +CREATE TABLE pg_catalog.pg_collation +( + oid OID NOT NULL, + collname TEXT NOT NULL, + collnamespace OID NOT NULL +); + +-- postgres=# \d pg_index; +-- Table "pg_catalog.pg_index" +-- Column | Type | Collation | Nullable | Default +-- ----------------+--------------+-----------+----------+--------- +-- indexrelid | oid | | not null | +-- indrelid | oid | | not null | +-- indnatts | smallint | | not null | +-- indnkeyatts | smallint | | not null | +-- indisunique | boolean | | not null | +-- indisprimary | boolean | | not null | +-- indisexclusion | boolean | | not null | +-- indimmediate | boolean | | not null | +-- indisclustered | boolean | | not null | +-- indisvalid | boolean | | not null | +-- indcheckxmin | boolean | | not null | +-- indisready | boolean | | not null | +-- indislive | boolean | | not null | +-- indisreplident | boolean | | not null | +-- indkey | int2vector | | not null | +-- indcollation | oidvector | | not null | +-- indclass | oidvector | | not null | +-- indoption | int2vector | | not null | +-- indexprs | pg_node_tree | C | | +-- indpred | pg_node_tree | C | | +-- Indexes: +-- "pg_index_indexrelid_index" PRIMARY KEY, btree (indexrelid) +-- "pg_index_indrelid_index" btree (indrelid) +CREATE TABLE pg_catalog.pg_index +( + indexrelid OID NOT NULL, + indrelid OID NOT NULL, + indisunique BOOLEAN NOT NULL, + indisprimary BOOLEAN NOT NULL, + indisvalid BOOLEAN NOT NULL +); + +-- postgres=# \d pg_catalog.pg_attrdef; +-- Table "pg_catalog.pg_attrdef" +-- Column | Type | Collation | Nullable | Default +-- ---------+--------------+-----------+----------+--------- +-- oid | oid | | not null | +-- adrelid | oid | | not null | +-- adnum | smallint | | not null | +-- adbin | pg_node_tree | C | not null | +-- Indexes: +-- "pg_attrdef_oid_index" PRIMARY KEY, btree (oid) +-- "pg_attrdef_adrelid_adnum_index" UNIQUE CONSTRAINT, btree (adrelid, adnum) +CREATE TABLE pg_catalog.pg_attrdef +( + adrelid OID NOT NULL, + admin INT NOT NULL, + adbin PG_NODE_TREE NOT NULL +); + +-- postgres=# \d pg_catalog.pg_depend +-- Table "pg_catalog.pg_depend" +-- Column | Type | Collation | Nullable | Default +-- -------------+---------+-----------+----------+--------- +-- classid | oid | | not null | +-- objid | oid | | not null | +-- objsubid | integer | | not null | +-- refclassid | oid | | not null | +-- refobjid | oid | | not null | +-- refobjsubid | integer | | not null | +-- deptype | "char" | | not null | +-- Indexes: +-- "pg_depend_depender_index" btree (classid, objid, objsubid) +-- "pg_depend_reference_index" btree (refclassid, refobjid, refobjsubid) +CREATE TABLE pg_catalog.pg_depend +( + objid OID, + refobjid OID, + deptype char +); + + +-- postgres=# \d pg_catalog.pg_constraint +-- Table "pg_catalog.pg_constraint" +-- Column | Type | Collation | Nullable | Default +-- ---------------+--------------+-----------+----------+--------- +-- oid | oid | | not null | +-- conname | name | | not null | +-- connamespace | oid | | not null | +-- contype | "char" | | not null | +-- condeferrable | boolean | | not null | +-- condeferred | boolean | | not null | +-- convalidated | boolean | | not null | +-- conrelid | oid | | not null | +-- contypid | oid | | not null | +-- conindid | oid | | not null | +-- conparentid | oid | | not null | +-- confrelid | oid | | not null | +-- confupdtype | "char" | | not null | +-- confdeltype | "char" | | not null | +-- confmatchtype | "char" | | not null | +-- conislocal | boolean | | not null | +-- coninhcount | integer | | not null | +-- connoinherit | boolean | | not null | +-- conkey | smallint[] | | | +-- confkey | smallint[] | | | +-- conpfeqop | oid[] | | | +-- conppeqop | oid[] | | | +-- conffeqop | oid[] | | | +-- conexclop | oid[] | | | +-- conbin | pg_node_tree | C | | +-- Indexes: +-- "pg_constraint_oid_index" PRIMARY KEY, btree (oid) +-- "pg_constraint_conname_nsp_index" btree (conname, connamespace) +-- "pg_constraint_conparentid_index" btree (conparentid) +-- "pg_constraint_conrelid_contypid_conname_index" UNIQUE CONSTRAINT, btree (conrelid, contypid, conname) +-- "pg_constraint_contypid_index" btree (contypid) +CREATE TABLE pg_catalog.pg_constraint +( + oid OID NOT NULL, + conname TEXT NOT NULL, + conindid OID NOT NULL, + contype CHAR NOT NULL, + condeferrable BOOLEAN NOT NULL, + condeferred BOOLEAN NOT NULL, + convalidated BOOLEAN NOT NULL, + conrelid OID NOT NULL, + conislocal BOOLEAN NOT NULL, + connoinherit BOOLEAN NOT NULL, + conbin PG_NODE_TREE NOT NULL +); + +-- postgres=# \d pg_catalog.pg_proc +-- Table "pg_catalog.pg_proc" +-- Column | Type | Collation | Nullable | Default +-- -----------------+--------------+-----------+----------+--------- +-- oid | oid | | not null | +-- proname | name | | not null | +-- pronamespace | oid | | not null | +-- proowner | oid | | not null | +-- prolang | oid | | not null | +-- procost | real | | not null | +-- prorows | real | | not null | +-- provariadic | oid | | not null | +-- prosupport | regproc | | not null | +-- prokind | "char" | | not null | +-- prosecdef | boolean | | not null | +-- proleakproof | boolean | | not null | +-- proisstrict | boolean | | not null | +-- proretset | boolean | | not null | +-- provolatile | "char" | | not null | +-- proparallel | "char" | | not null | +-- pronargs | smallint | | not null | +-- pronargdefaults | smallint | | not null | +-- prorettype | oid | | not null | +-- proargtypes | oidvector | | not null | +-- proallargtypes | oid[] | | | +-- proargmodes | "char"[] | | | +-- proargnames | text[] | C | | +-- proargdefaults | pg_node_tree | C | | +-- protrftypes | oid[] | | | +-- prosrc | text | C | not null | +-- probin | text | C | | +-- prosqlbody | pg_node_tree | C | | +-- proconfig | text[] | C | | +-- proacl | aclitem[] | | | +-- Indexes: +-- "pg_proc_oid_index" PRIMARY KEY, btree (oid) +-- "pg_proc_proname_args_nsp_index" UNIQUE CONSTRAINT, btree (proname, proargtypes, pronamespace) +CREATE TABLE pg_catalog.pg_proc +( + oid OID NOT NULL, + proname TEXT NOT NULL, + pronamespace OID NOT NULL, + prolang OID NOT NULL, + prokind CHAR NOT NULL +); + + +-- postgres=# \d pg_catalog.pg_language +-- Table "pg_catalog.pg_language" +-- Column | Type | Collation | Nullable | Default +-- ---------------+-----------+-----------+----------+--------- +-- oid | oid | | not null | +-- lanname | name | | not null | +-- lanowner | oid | | not null | +-- lanispl | boolean | | not null | +-- lanpltrusted | boolean | | not null | +-- lanplcallfoid | oid | | not null | +-- laninline | oid | | not null | +-- lanvalidator | oid | | not null | +-- lanacl | aclitem[] | | | +-- Indexes: +-- "pg_language_oid_index" PRIMARY KEY, btree (oid) +-- "pg_language_name_index" UNIQUE CONSTRAINT, btree (lanname) +CREATE TABLE pg_catalog.pg_language +( + oid OID NOT NULL, + lanname TEXT NOT NULL +); + + +-- postgres=# \d pg_catalog.pg_trigger; +-- Table "pg_catalog.pg_trigger" +-- Column | Type | Collation | Nullable | Default +-- ----------------+--------------+-----------+----------+--------- +-- oid | oid | | not null | +-- tgrelid | oid | | not null | +-- tgparentid | oid | | not null | +-- tgname | name | | not null | +-- tgfoid | oid | | not null | +-- tgtype | smallint | | not null | +-- tgenabled | "char" | | not null | +-- tgisinternal | boolean | | not null | +-- tgconstrrelid | oid | | not null | +-- tgconstrindid | oid | | not null | +-- tgconstraint | oid | | not null | +-- tgdeferrable | boolean | | not null | +-- tginitdeferred | boolean | | not null | +-- tgnargs | smallint | | not null | +-- tgattr | int2vector | | not null | +-- tgargs | bytea | | not null | +-- tgqual | pg_node_tree | C | | +-- tgoldtable | name | | | +-- tgnewtable | name | | | +-- Indexes: +-- "pg_trigger_oid_index" PRIMARY KEY, btree (oid) +-- "pg_trigger_tgrelid_tgname_index" UNIQUE CONSTRAINT, btree (tgrelid, tgname) +-- "pg_trigger_tgconstraint_index" btree (tgconstraint) +CREATE TABLE pg_catalog.pg_trigger +( + oid OID NOT NULL, + tgrelid OID NOT NULL, + tgparentid OID NOT NULL, + tgname TEXT NOT NULL, + tfoid OID NOT NULL, + tgisinternal BOOLEAN NOT NULL +); diff --git a/internal/schema/schema.go b/internal/schema/schema.go new file mode 100644 index 0000000..131cfc0 --- /dev/null +++ b/internal/schema/schema.go @@ -0,0 +1,500 @@ +package schema + +import ( + "context" + "fmt" + "regexp" + "sort" + + "github.com/mitchellh/hashstructure/v2" + "github.com/stripe/pg-schema-diff/internal/queries" +) + +type ( + // Object represents a resource in a schema (table, column, index...) + Object interface { + // GetName is used to identify the old and new versions of a schema object between the old and new schemas + // If the name is not present in the old schema objects list, then it is added + // If the name is not present in the new schemas objects list, then it is removed + // Otherwise, it has persisted across two schemas and is possibly altered + GetName() string + } + + // SchemaQualifiedName represents a schema object name scoped within a schema + SchemaQualifiedName struct { + SchemaName string + // EscapedName is the name of the object. It should already be escaped + // We take an escaped name because there are weird exceptions, like functions, where we can't just + // surround the name in quotes + EscapedName string + } +) + +func (o SchemaQualifiedName) GetName() string { + return o.GetFQEscapedName() +} + +// GetFQEscapedName gets the fully-qualified, escaped name of the schema object, including the schema name +func (o SchemaQualifiedName) GetFQEscapedName() string { + return fmt.Sprintf("%s.%s", EscapeIdentifier(o.SchemaName), o.EscapedName) +} + +func (o SchemaQualifiedName) IsEmpty() bool { + return len(o.SchemaName) == 0 +} + +type Schema struct { + // Name refers to the name of the schema. Ultimately, schema objects can cut across + // schemas, e.g., a partition of a table can exist in a different table. Thus, we're probably + // going to delete this Name attribute soon, once multi-schema is supported + Name string + Tables []Table + Indexes []Index + + Functions []Function + Triggers []Trigger +} + +func (s Schema) GetName() string { + return s.Name +} + +// Normalize normalizes the schema (alphabetically sorts tables and columns in tables) +// Useful for hashing and testing +func (s Schema) Normalize() Schema { + var normTables []Table + for _, table := range sortSchemaObjectsByName(s.Tables) { + // Don't normalize columns order. their order is derived from the postgres catalogs + // (relevant to data packing) + var normCheckConstraints []CheckConstraint + for _, checkConstraint := range sortSchemaObjectsByName(table.CheckConstraints) { + checkConstraint.DependsOnFunctions = sortSchemaObjectsByName(checkConstraint.DependsOnFunctions) + normCheckConstraints = append(normCheckConstraints, checkConstraint) + } + table.CheckConstraints = normCheckConstraints + normTables = append(normTables, table) + } + s.Tables = normTables + + s.Indexes = sortSchemaObjectsByName(s.Indexes) + + var normFunctions []Function + for _, function := range sortSchemaObjectsByName(s.Functions) { + function.DependsOnFunctions = sortSchemaObjectsByName(function.DependsOnFunctions) + normFunctions = append(normFunctions, function) + } + s.Functions = normFunctions + + s.Triggers = sortSchemaObjectsByName(s.Triggers) + + return s +} + +// sortSchemaObjectsByName returns a (copied) sorted list of schema objects. +func sortSchemaObjectsByName[S Object](vals []S) []S { + clonedVals := make([]S, len(vals)) + copy(clonedVals, vals) + sort.Slice(clonedVals, func(i, j int) bool { + return clonedVals[i].GetName() < clonedVals[j].GetName() + }) + return clonedVals +} + +func (s Schema) Hash() (string, error) { + // alternatively, we can print the struct as a string and hash it + hashVal, err := hashstructure.Hash(s.Normalize(), hashstructure.FormatV2, nil) + if err != nil { + return "", fmt.Errorf("hashing schema: %w", err) + } + return fmt.Sprintf("%x", hashVal), nil +} + +type Table struct { + Name string + Columns []Column + CheckConstraints []CheckConstraint + + // PartitionKeyDef is the output of Pg function pg_get_partkeydef: + // PARTITION BY $PartitionKeyDef + // If empty, then the table is not partitioned + PartitionKeyDef string + + ParentTableName string + ForValues string +} + +func (t Table) IsPartitioned() bool { + return len(t.PartitionKeyDef) > 0 +} + +func (t Table) IsPartition() bool { + return len(t.ForValues) > 0 +} + +func (t Table) GetName() string { + return t.Name +} + +type Column struct { + Name string + Type string + Collation SchemaQualifiedName + // If the column has a default value, this will be a SQL string representing that value. + // Examples: + // ''::text + // CURRENT_TIMESTAMP + // If empty, indicates that there is no default value. + Default string + IsNullable bool + + // Size is the number of bytes required to store the value. + // It is used for data-packing purposes + Size int // +} + +func (c Column) GetName() string { + return c.Name +} + +func (c Column) IsCollated() bool { + return !c.Collation.IsEmpty() +} + +var ( + // The first matching group is the "CREATE [UNIQUE] INDEX ". UNIQUE is an optional match + // because only UNIQUE indices will have the UNIQUE keyword in their pg_get_indexdef statement + // + // The third matching group is the rest of the statement + idxToConcurrentlyRegex = regexp.MustCompile("^(CREATE (UNIQUE )?INDEX )(.*)$") +) + +// GetIndexDefStatement is the output of pg_getindexdef. It is a `CREATE INDEX` statement that will re-create +// the index. This statement does not contain `CONCURRENTLY`. +// For unique indexes, it does contain `UNIQUE` +// For partitioned tables, it does contain `ONLY` +type GetIndexDefStatement string + +func (i GetIndexDefStatement) ToCreateIndexConcurrently() (string, error) { + if !idxToConcurrentlyRegex.MatchString(string(i)) { + return "", fmt.Errorf("%s follows an unexpected structure", i) + } + return idxToConcurrentlyRegex.ReplaceAllString(string(i), "${1}CONCURRENTLY ${3}"), nil +} + +type Index struct { + TableName string + Name string + Columns []string + IsInvalid bool + IsPk bool + IsUnique bool + // ConstraintName is the name of the constraint associated with an index. Empty string if no associated constraint. + // Once we need support for constraints not associated with indexes, we'll add a + // Constraint schema object and starting fetching constraints directly + ConstraintName string + + // GetIndexDefStmt is the output of pg_getindexdef + GetIndexDefStmt GetIndexDefStatement + + // ParentIdxName is the name of the parent index if the index is a partition of an index + ParentIdxName string +} + +func (i Index) GetName() string { + return i.Name +} + +func (i Index) IsPartitionOfIndex() bool { + return len(i.ParentIdxName) > 0 +} + +type CheckConstraint struct { + Name string + Expression string + IsValid bool + IsInheritable bool + DependsOnFunctions []SchemaQualifiedName +} + +func (c CheckConstraint) GetName() string { + return c.Name +} + +type Function struct { + SchemaQualifiedName + // FunctionDef is the statement required to completely (re)create + // the function, as returned by `pg_get_functiondef`. It is a CREATE OR REPLACE + // statement + FunctionDef string + // Language is the language of the function. This is relevant in determining if we + // can track the dependencies of the function (or not) + Language string + DependsOnFunctions []SchemaQualifiedName +} + +var ( + // The first matching group is the "CREATE ". The second matching group is the rest of the statement + triggerToOrReplaceRegex = regexp.MustCompile("^(CREATE )(.*)$") +) + +// GetTriggerDefStatement is the output of pg_get_triggerdef. It is a `CREATE TRIGGER` statement that will create +// the trigger. This statement does not contain `OR REPLACE` +type GetTriggerDefStatement string + +func (g GetTriggerDefStatement) ToCreateOrReplace() (string, error) { + if !triggerToOrReplaceRegex.MatchString(string(g)) { + return "", fmt.Errorf("%s follows an unexpected structure", g) + } + return triggerToOrReplaceRegex.ReplaceAllString(string(g), "${1}OR REPLACE ${2}"), nil +} + +type Trigger struct { + EscapedName string + OwningTable SchemaQualifiedName + // OwningTableUnescapedName lets us be backwards compatible with the TableSQLVertexGenerator, which + // currently uses the unescaped name as the vertex id. This will be removed once the TableSQLVertexGenerator + // is migrated to use SchemaQualifiedName + OwningTableUnescapedName string + Function SchemaQualifiedName + // GetTriggerDefStmt is the statement required to completely (re)create the trigger, as returned + // by pg_get_triggerdef + GetTriggerDefStmt GetTriggerDefStatement +} + +func (t Trigger) GetName() string { + return t.OwningTable.GetFQEscapedName() + "_" + t.EscapedName +} + +// GetPublicSchema fetches the "public" schema. It is a non-atomic operation +func GetPublicSchema(ctx context.Context, db queries.DBTX) (Schema, error) { + q := queries.New(db) + + tables, err := fetchTables(ctx, q) + if err != nil { + return Schema{}, fmt.Errorf("fetchTables: %w", err) + } + + indexes, err := fetchIndexes(ctx, q) + if err != nil { + return Schema{}, fmt.Errorf("fetchIndexes: %w", err) + } + + functions, err := fetchFunctions(ctx, q) + if err != nil { + return Schema{}, fmt.Errorf("fetchFunctions: %w", err) + } + + triggers, err := fetchTriggers(ctx, q) + if err != nil { + return Schema{}, fmt.Errorf("fetchTriggers: %w", err) + } + + return Schema{ + Name: "public", + Tables: tables, + Indexes: indexes, + Functions: functions, + Triggers: triggers, + }, nil +} + +func fetchTables(ctx context.Context, q *queries.Queries) ([]Table, error) { + rawTables, err := q.GetTables(ctx) + if err != nil { + return nil, fmt.Errorf("GetTables(): %w", err) + } + + tablesToCheckConsMap, err := fetchCheckConsAndBuildTableToCheckConsMap(ctx, q) + if err != nil { + return nil, fmt.Errorf("fetchCheckConsAndBuildTableToCheckConsMap: %w", err) + } + + var tables []Table + for _, table := range rawTables { + if len(table.ParentTableName) > 0 && table.ParentTableSchemaName != "public" { + return nil, fmt.Errorf( + "table %s has parent table in schema %s. only parent tables in public schema are supported: %w", + table.TableName, + table.ParentTableSchemaName, + err, + ) + } + + rawColumns, err := q.GetColumnsForTable(ctx, table.Oid) + if err != nil { + return nil, fmt.Errorf("GetColumnsForTable(%s): %w", table.Oid, err) + } + var columns []Column + for _, column := range rawColumns { + collation := SchemaQualifiedName{} + if len(column.CollationName) > 0 { + collation = SchemaQualifiedName{ + EscapedName: EscapeIdentifier(column.CollationName), + SchemaName: column.CollationSchemaName, + } + } + + columns = append(columns, Column{ + Name: column.ColumnName, + Type: column.ColumnType, + Collation: collation, + IsNullable: !column.IsNotNull, + // If the column has a default value, this will be a SQL string representing that value. + // Examples: + // ''::text + // CURRENT_TIMESTAMP + // If empty, indicates that there is no default value. + Default: column.DefaultValue, + Size: int(column.ColumnSize), + }) + } + + tables = append(tables, Table{ + Name: table.TableName, + Columns: columns, + CheckConstraints: tablesToCheckConsMap[table.TableName], + + PartitionKeyDef: table.PartitionKeyDef, + + ParentTableName: table.ParentTableName, + ForValues: table.PartitionForValues, + }) + } + return tables, nil +} + +// fetchCheckConsAndBuildTableToCheckConsMap fetches the check constraints and builds a map of table name to the check +// constraints within the table +func fetchCheckConsAndBuildTableToCheckConsMap(ctx context.Context, q *queries.Queries) (map[string][]CheckConstraint, error) { + rawCheckCons, err := q.GetCheckConstraints(ctx) + if err != nil { + return nil, fmt.Errorf("GetCheckConstraints: %w", err) + } + + result := make(map[string][]CheckConstraint) + for _, cc := range rawCheckCons { + dependsOnFunctions, err := fetchDependsOnFunctions(ctx, q, cc.Oid) + if err != nil { + return nil, fmt.Errorf("fetchDependsOnFunctions(%s): %w", cc.Oid, err) + } + + checkCon := CheckConstraint{ + Name: cc.Name, + Expression: cc.Expression, + IsValid: cc.IsValid, + IsInheritable: !cc.IsNotInheritable, + DependsOnFunctions: dependsOnFunctions, + } + result[cc.TableName] = append(result[cc.TableName], checkCon) + } + + return result, nil +} + +// fetchIndexes fetches the indexes We fetch all indexes at once to minimize number of queries, since each index needs +// to fetch columns +func fetchIndexes(ctx context.Context, q *queries.Queries) ([]Index, error) { + rawIndexes, err := q.GetIndexes(ctx) + if err != nil { + return nil, fmt.Errorf("GetColumnsInPublicSchema: %w", err) + } + + var indexes []Index + for _, rawIndex := range rawIndexes { + rawColumns, err := q.GetColumnsForIndex(ctx, rawIndex.Oid) + if err != nil { + return nil, fmt.Errorf("GetColumnsForIndex(%s): %w", rawIndex.Oid, err) + } + + indexes = append(indexes, Index{ + TableName: rawIndex.TableName, + Name: rawIndex.IndexName, + Columns: rawColumns, + GetIndexDefStmt: GetIndexDefStatement(rawIndex.DefStmt), + IsInvalid: !rawIndex.IndexIsValid, + IsPk: rawIndex.IndexIsPk, + IsUnique: rawIndex.IndexIsUnique, + ConstraintName: rawIndex.ConstraintName, + ParentIdxName: rawIndex.ParentIndexName, + }) + } + + return indexes, nil +} + +// fetchFunctions fetches the functions required to +func fetchFunctions(ctx context.Context, q *queries.Queries) ([]Function, error) { + rawFunctions, err := q.GetFunctions(ctx) + if err != nil { + return nil, fmt.Errorf("GetFunctions: %w", err) + } + + var functions []Function + for _, rawFunction := range rawFunctions { + dependsOnFunctions, err := fetchDependsOnFunctions(ctx, q, rawFunction.Oid) + if err != nil { + return nil, fmt.Errorf("fetchDependsOnFunctions(%s): %w", rawFunction.Oid, err) + } + + functions = append(functions, Function{ + SchemaQualifiedName: buildFuncName(rawFunction.FuncName, rawFunction.FuncIdentityArguments, rawFunction.FuncSchemaName), + FunctionDef: rawFunction.FuncDef, + Language: rawFunction.FuncLang, + DependsOnFunctions: dependsOnFunctions, + }) + } + + return functions, nil +} + +func fetchDependsOnFunctions(ctx context.Context, q *queries.Queries, oid any) ([]SchemaQualifiedName, error) { + dependsOnFunctions, err := q.GetDependsOnFunctions(ctx, oid) + if err != nil { + return nil, err + } + + var functionNames []SchemaQualifiedName + for _, rawFunction := range dependsOnFunctions { + functionNames = append(functionNames, buildFuncName(rawFunction.FuncName, rawFunction.FuncIdentityArguments, rawFunction.FuncSchemaName)) + } + + return functionNames, nil +} + +func fetchTriggers(ctx context.Context, q *queries.Queries) ([]Trigger, error) { + rawTriggers, err := q.GetTriggers(ctx) + if err != nil { + return nil, fmt.Errorf("GetTriggers: %w", err) + } + + var triggers []Trigger + for _, rawTrigger := range rawTriggers { + triggers = append(triggers, Trigger{ + EscapedName: EscapeIdentifier(rawTrigger.TriggerName), + OwningTable: buildNameFromUnescaped(rawTrigger.OwningTableName, rawTrigger.OwningTableSchemaName), + OwningTableUnescapedName: rawTrigger.OwningTableName, + Function: buildFuncName(rawTrigger.FuncName, rawTrigger.FuncIdentityArguments, rawTrigger.FuncSchemaName), + GetTriggerDefStmt: GetTriggerDefStatement(rawTrigger.TriggerDef), + }) + } + + return triggers, nil +} + +func buildFuncName(name, identityArguments, schemaName string) SchemaQualifiedName { + return SchemaQualifiedName{ + SchemaName: schemaName, + EscapedName: fmt.Sprintf("\"%s\"(%s)", name, identityArguments), + } +} + +func buildNameFromUnescaped(unescapedName, schemaName string) SchemaQualifiedName { + return SchemaQualifiedName{ + EscapedName: EscapeIdentifier(unescapedName), + SchemaName: schemaName, + } +} + +func EscapeIdentifier(name string) string { + return fmt.Sprintf("\"%s\"", name) +} diff --git a/internal/schema/schema_test.go b/internal/schema/schema_test.go new file mode 100644 index 0000000..d025cad --- /dev/null +++ b/internal/schema/schema_test.go @@ -0,0 +1,886 @@ +package schema_test + +import ( + "context" + "database/sql" + "testing" + + _ "github.com/jackc/pgx/v4/stdlib" + "github.com/kr/pretty" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stripe/pg-schema-diff/internal/pgengine" + "github.com/stripe/pg-schema-diff/internal/schema" +) + +type testCase struct { + name string + ddl []string + expectedSchema schema.Schema + expectedHash string + expectedErrIs error +} + +var ( + defaultCollation = schema.SchemaQualifiedName{ + EscapedName: `"default"`, + SchemaName: "pg_catalog", + } + cCollation = schema.SchemaQualifiedName{ + EscapedName: `"C"`, + SchemaName: "pg_catalog", + } + + testCases = []*testCase{ + { + name: "Simple test", + ddl: []string{` + CREATE FUNCTION add(a integer, b integer) RETURNS integer + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN a + b; + + CREATE FUNCTION increment(i integer) RETURNS integer AS $$ + BEGIN + RETURN i + 1; + END; + $$ LANGUAGE plpgsql; + + CREATE FUNCTION function_with_dependencies(a integer, b integer) RETURNS integer + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN add(a, b) + increment(a); + + CREATE TABLE foo ( + id INTEGER PRIMARY KEY, + author TEXT COLLATE "C", + content TEXT NOT NULL DEFAULT '', + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP CHECK (created_at > CURRENT_TIMESTAMP - interval '1 month') NO INHERIT, + version INT NOT NULL DEFAULT 0, + CHECK ( function_with_dependencies(id, id) > 0) + ); + + ALTER TABLE foo ADD CONSTRAINT author_check CHECK (author IS NOT NULL AND LENGTH(author) > 0) NO INHERIT NOT VALID; + CREATE INDEX some_idx ON foo USING hash (content); + CREATE UNIQUE INDEX some_unique_idx ON foo (created_at DESC, author ASC); + + CREATE FUNCTION increment_version() RETURNS TRIGGER AS $$ + BEGIN + NEW.version = OLD.version + 1; + RETURN NEW; + END; + $$ language 'plpgsql'; + + CREATE TRIGGER some_trigger + BEFORE UPDATE ON foo + FOR EACH ROW + WHEN (OLD.* IS DISTINCT FROM NEW.*) + EXECUTE PROCEDURE increment_version(); + `}, + expectedHash: "9648c294aed76ef6", + expectedSchema: schema.Schema{ + Name: "public", + Tables: []schema.Table{ + { + Name: "foo", + Columns: []schema.Column{ + {Name: "id", Type: "integer", Size: 4}, + {Name: "author", Type: "text", IsNullable: true, Size: -1, Collation: cCollation}, + {Name: "content", Type: "text", Default: "''::text", Size: -1, Collation: defaultCollation}, + {Name: "created_at", Type: "timestamp without time zone", Default: "CURRENT_TIMESTAMP", Size: 8}, + {Name: "version", Type: "integer", Default: "0", Size: 4}, + }, + CheckConstraints: []schema.CheckConstraint{ + {Name: "author_check", Expression: "((author IS NOT NULL) AND (length(author) > 0))"}, + {Name: "foo_created_at_check", Expression: "(created_at > (CURRENT_TIMESTAMP - '1 mon'::interval))", IsValid: true}, + { + Name: "foo_id_check", + Expression: "(function_with_dependencies(id, id) > 0)", + IsValid: true, + IsInheritable: true, + DependsOnFunctions: []schema.SchemaQualifiedName{ + {EscapedName: "\"function_with_dependencies\"(a integer, b integer)", SchemaName: "public"}, + }, + }, + }, + }, + }, + Indexes: []schema.Index{ + { + TableName: "foo", + Name: "foo_pkey", Columns: []string{"id"}, IsPk: true, IsUnique: true, ConstraintName: "foo_pkey", + GetIndexDefStmt: "CREATE UNIQUE INDEX foo_pkey ON public.foo USING btree (id)", + }, + { + TableName: "foo", + Name: "some_idx", Columns: []string{"content"}, + GetIndexDefStmt: "CREATE INDEX some_idx ON public.foo USING hash (content)", + }, + { + TableName: "foo", + Name: "some_unique_idx", Columns: []string{"created_at", "author"}, IsPk: false, IsUnique: true, + GetIndexDefStmt: "CREATE UNIQUE INDEX some_unique_idx ON public.foo USING btree (created_at DESC, author)", + }, + }, + Functions: []schema.Function{ + { + SchemaQualifiedName: schema.SchemaQualifiedName{EscapedName: "\"add\"(a integer, b integer)", SchemaName: "public"}, + FunctionDef: "CREATE OR REPLACE FUNCTION public.add(a integer, b integer)\n RETURNS integer\n LANGUAGE sql\n IMMUTABLE STRICT\nRETURN (a + b)\n", + Language: "sql", + }, + { + SchemaQualifiedName: schema.SchemaQualifiedName{EscapedName: "\"function_with_dependencies\"(a integer, b integer)", SchemaName: "public"}, + FunctionDef: "CREATE OR REPLACE FUNCTION public.function_with_dependencies(a integer, b integer)\n RETURNS integer\n LANGUAGE sql\n IMMUTABLE STRICT\nRETURN (add(a, b) + increment(a))\n", + Language: "sql", + DependsOnFunctions: []schema.SchemaQualifiedName{ + {EscapedName: "\"add\"(a integer, b integer)", SchemaName: "public"}, + {EscapedName: "\"increment\"(i integer)", SchemaName: "public"}, + }, + }, + { + SchemaQualifiedName: schema.SchemaQualifiedName{EscapedName: "\"increment\"(i integer)", SchemaName: "public"}, + FunctionDef: "CREATE OR REPLACE FUNCTION public.increment(i integer)\n RETURNS integer\n LANGUAGE plpgsql\nAS $function$\n\t\t\t\t\tBEGIN\n\t\t\t\t\t\t\tRETURN i + 1;\n\t\t\t\t\tEND;\n\t\t\t$function$\n", + Language: "plpgsql", + }, + { + SchemaQualifiedName: schema.SchemaQualifiedName{EscapedName: "\"increment_version\"()", SchemaName: "public"}, + FunctionDef: "CREATE OR REPLACE FUNCTION public.increment_version()\n RETURNS trigger\n LANGUAGE plpgsql\nAS $function$\n\t\t\t\tBEGIN\n\t\t\t\t\tNEW.version = OLD.version + 1;\n\t\t\t\t\tRETURN NEW;\n\t\t\t\tEND;\n\t\t\t$function$\n", + Language: "plpgsql", + }, + }, + Triggers: []schema.Trigger{ + { + EscapedName: "\"some_trigger\"", + OwningTable: schema.SchemaQualifiedName{EscapedName: "\"foo\"", SchemaName: "public"}, + OwningTableUnescapedName: "foo", + Function: schema.SchemaQualifiedName{EscapedName: "\"increment_version\"()", SchemaName: "public"}, + GetTriggerDefStmt: "CREATE TRIGGER some_trigger BEFORE UPDATE ON public.foo FOR EACH ROW WHEN ((old.* IS DISTINCT FROM new.*)) EXECUTE FUNCTION increment_version()", + }, + }, + }, + }, + { + name: "Simple partition test", + ddl: []string{` + CREATE TABLE foo ( + id INTEGER CHECK (id > 0), + author TEXT COLLATE "C", + content TEXT DEFAULT '', + genre VARCHAR(256) NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP CHECK (created_at > CURRENT_TIMESTAMP - interval '1 month'), + PRIMARY KEY (author, id) + ) PARTITION BY LIST (author); + ALTER TABLE foo ADD CONSTRAINT author_check CHECK (author IS NOT NULL AND LENGTH(author) > 0) NOT VALID; + + CREATE FUNCTION increment_version() RETURNS TRIGGER AS $$ + BEGIN + NEW.version = OLD.version + 1; + RETURN NEW; + END; + $$ language 'plpgsql'; + + CREATE TRIGGER some_trigger + BEFORE UPDATE ON foo + FOR EACH ROW + WHEN (OLD.* IS DISTINCT FROM NEW.*) + EXECUTE PROCEDURE increment_version(); + + CREATE TABLE foo_1 PARTITION OF foo( + content NOT NULL + ) FOR VALUES IN ('some author 1'); + CREATE TABLE foo_2 PARTITION OF foo FOR VALUES IN ('some author 2'); + CREATE TABLE foo_3 PARTITION OF foo FOR VALUES IN ('some author 3'); + + -- partitioned indexes + CREATE INDEX some_partitioned_idx ON foo USING hash(author); + CREATE UNIQUE INDEX some_unique_partitioned_idx ON foo(author, created_at DESC); + CREATE INDEX some_invalid_idx ON ONLY foo(author, genre); + + -- local indexes + CREATE UNIQUE INDEX foo_1_local_idx ON foo_1(author DESC, id); + CREATE UNIQUE INDEX foo_2_local_idx ON foo_2(author, content); + CREATE UNIQUE INDEX foo_3_local_idx ON foo_3(author, created_at); + + CREATE TRIGGER some_partition_trigger + BEFORE UPDATE ON foo_1 + FOR EACH ROW + WHEN (OLD.* IS DISTINCT FROM NEW.*) + EXECUTE PROCEDURE increment_version(); + + `}, + expectedHash: "651c9229cd8373f0", + expectedSchema: schema.Schema{ + Name: "public", + Tables: []schema.Table{ + { + Name: "foo", + Columns: []schema.Column{ + {Name: "id", Type: "integer", Size: 4}, + {Name: "author", Type: "text", Size: -1, Collation: cCollation}, + {Name: "content", Type: "text", Default: "''::text", IsNullable: true, Size: -1, Collation: defaultCollation}, + {Name: "genre", Type: "character varying(256)", Size: -1, Collation: defaultCollation}, + {Name: "created_at", Type: "timestamp without time zone", Default: "CURRENT_TIMESTAMP", Size: 8}, + }, + CheckConstraints: []schema.CheckConstraint{ + {Name: "author_check", Expression: "((author IS NOT NULL) AND (length(author) > 0))", IsInheritable: true}, + {Name: "foo_created_at_check", Expression: "(created_at > (CURRENT_TIMESTAMP - '1 mon'::interval))", IsValid: true, IsInheritable: true}, + {Name: "foo_id_check", Expression: "(id > 0)", IsValid: true, IsInheritable: true}, + }, + PartitionKeyDef: "LIST (author)", + }, + { + ParentTableName: "foo", + Name: "foo_1", + Columns: []schema.Column{ + {Name: "id", Type: "integer", Size: 4}, + {Name: "author", Type: "text", Size: -1, Collation: cCollation}, + {Name: "content", Type: "text", Default: "''::text", Size: -1, Collation: defaultCollation}, + {Name: "genre", Type: "character varying(256)", Size: -1, Collation: defaultCollation}, + {Name: "created_at", Type: "timestamp without time zone", Default: "CURRENT_TIMESTAMP", Size: 8}, + }, + CheckConstraints: nil, + ForValues: "FOR VALUES IN ('some author 1')", + }, + { + ParentTableName: "foo", + Name: "foo_2", + Columns: []schema.Column{ + {Name: "id", Type: "integer", Size: 4}, + {Name: "author", Type: "text", Size: -1, Collation: cCollation}, + {Name: "content", Type: "text", Default: "''::text", IsNullable: true, Size: -1, Collation: defaultCollation}, + {Name: "genre", Type: "character varying(256)", Size: -1, Collation: defaultCollation}, + {Name: "created_at", Type: "timestamp without time zone", Default: "CURRENT_TIMESTAMP", Size: 8}, + }, + CheckConstraints: nil, + ForValues: "FOR VALUES IN ('some author 2')", + }, + { + ParentTableName: "foo", + Name: "foo_3", + Columns: []schema.Column{ + {Name: "id", Type: "integer", Size: 4}, + {Name: "author", Type: "text", Size: -1, Collation: cCollation}, + {Name: "content", Type: "text", Default: "''::text", IsNullable: true, Size: -1, Collation: defaultCollation}, + {Name: "genre", Type: "character varying(256)", Size: -1, Collation: defaultCollation}, + {Name: "created_at", Type: "timestamp without time zone", Default: "CURRENT_TIMESTAMP", Size: 8}, + }, + CheckConstraints: nil, + ForValues: "FOR VALUES IN ('some author 3')", + }, + }, + Indexes: []schema.Index{ + { + TableName: "foo", + Name: "foo_pkey", Columns: []string{"author", "id"}, IsPk: true, IsUnique: true, ConstraintName: "foo_pkey", + GetIndexDefStmt: "CREATE UNIQUE INDEX foo_pkey ON ONLY public.foo USING btree (author, id)", + }, + { + TableName: "foo", + Name: "some_partitioned_idx", Columns: []string{"author"}, + GetIndexDefStmt: "CREATE INDEX some_partitioned_idx ON ONLY public.foo USING hash (author)", + }, + { + TableName: "foo", + Name: "some_unique_partitioned_idx", Columns: []string{"author", "created_at"}, IsPk: false, IsUnique: true, + GetIndexDefStmt: "CREATE UNIQUE INDEX some_unique_partitioned_idx ON ONLY public.foo USING btree (author, created_at DESC)", + }, + { + TableName: "foo", + Name: "some_invalid_idx", Columns: []string{"author", "genre"}, IsInvalid: true, IsPk: false, IsUnique: false, + GetIndexDefStmt: "CREATE INDEX some_invalid_idx ON ONLY public.foo USING btree (author, genre)", + }, + // foo_1 indexes + { + TableName: "foo_1", + Name: "foo_1_author_idx", Columns: []string{"author"}, ParentIdxName: "some_partitioned_idx", + GetIndexDefStmt: "CREATE INDEX foo_1_author_idx ON public.foo_1 USING hash (author)", + }, + { + TableName: "foo_1", + Name: "foo_1_author_created_at_idx", Columns: []string{"author", "created_at"}, IsPk: false, IsUnique: true, ParentIdxName: "some_unique_partitioned_idx", + GetIndexDefStmt: "CREATE UNIQUE INDEX foo_1_author_created_at_idx ON public.foo_1 USING btree (author, created_at DESC)", + }, + { + TableName: "foo_1", + Name: "foo_1_local_idx", Columns: []string{"author", "id"}, IsUnique: true, + GetIndexDefStmt: "CREATE UNIQUE INDEX foo_1_local_idx ON public.foo_1 USING btree (author DESC, id)", + }, + { + TableName: "foo_1", + Name: "foo_1_pkey", Columns: []string{"author", "id"}, IsPk: true, IsUnique: true, ConstraintName: "foo_1_pkey", ParentIdxName: "foo_pkey", + GetIndexDefStmt: "CREATE UNIQUE INDEX foo_1_pkey ON public.foo_1 USING btree (author, id)", + }, + // foo_2 indexes + { + TableName: "foo_2", + Name: "foo_2_author_idx", Columns: []string{"author"}, ParentIdxName: "some_partitioned_idx", + GetIndexDefStmt: "CREATE INDEX foo_2_author_idx ON public.foo_2 USING hash (author)", + }, + { + TableName: "foo_2", + Name: "foo_2_author_created_at_idx", Columns: []string{"author", "created_at"}, IsPk: false, IsUnique: true, ParentIdxName: "some_unique_partitioned_idx", + GetIndexDefStmt: "CREATE UNIQUE INDEX foo_2_author_created_at_idx ON public.foo_2 USING btree (author, created_at DESC)", + }, + { + TableName: "foo_2", + Name: "foo_2_local_idx", Columns: []string{"author", "content"}, IsUnique: true, + GetIndexDefStmt: "CREATE UNIQUE INDEX foo_2_local_idx ON public.foo_2 USING btree (author, content)", + }, + { + TableName: "foo_2", + Name: "foo_2_pkey", Columns: []string{"author", "id"}, IsPk: true, IsUnique: true, ConstraintName: "foo_2_pkey", ParentIdxName: "foo_pkey", + GetIndexDefStmt: "CREATE UNIQUE INDEX foo_2_pkey ON public.foo_2 USING btree (author, id)", + }, + // foo_3 indexes + { + TableName: "foo_3", + Name: "foo_3_author_idx", Columns: []string{"author"}, ParentIdxName: "some_partitioned_idx", + GetIndexDefStmt: "CREATE INDEX foo_3_author_idx ON public.foo_3 USING hash (author)", + }, + { + TableName: "foo_3", + Name: "foo_3_author_created_at_idx", Columns: []string{"author", "created_at"}, IsPk: false, IsUnique: true, ParentIdxName: "some_unique_partitioned_idx", + GetIndexDefStmt: "CREATE UNIQUE INDEX foo_3_author_created_at_idx ON public.foo_3 USING btree (author, created_at DESC)", + }, + { + TableName: "foo_3", + Name: "foo_3_local_idx", Columns: []string{"author", "created_at"}, IsUnique: true, + GetIndexDefStmt: "CREATE UNIQUE INDEX foo_3_local_idx ON public.foo_3 USING btree (author, created_at)", + }, + { + TableName: "foo_3", + Name: "foo_3_pkey", Columns: []string{"author", "id"}, IsPk: true, IsUnique: true, ConstraintName: "foo_3_pkey", ParentIdxName: "foo_pkey", + GetIndexDefStmt: "CREATE UNIQUE INDEX foo_3_pkey ON public.foo_3 USING btree (author, id)", + }, + }, + + Functions: []schema.Function{ + { + SchemaQualifiedName: schema.SchemaQualifiedName{EscapedName: "\"increment_version\"()", SchemaName: "public"}, + FunctionDef: "CREATE OR REPLACE FUNCTION public.increment_version()\n RETURNS trigger\n LANGUAGE plpgsql\nAS $function$\n\t\t\t\tBEGIN\n\t\t\t\t\tNEW.version = OLD.version + 1;\n\t\t\t\t\tRETURN NEW;\n\t\t\t\tEND;\n\t\t\t$function$\n", + Language: "plpgsql", + }, + }, + Triggers: []schema.Trigger{ + { + EscapedName: "\"some_trigger\"", + OwningTable: schema.SchemaQualifiedName{EscapedName: "\"foo\"", SchemaName: "public"}, + OwningTableUnescapedName: "foo", + Function: schema.SchemaQualifiedName{EscapedName: "\"increment_version\"()", SchemaName: "public"}, + GetTriggerDefStmt: "CREATE TRIGGER some_trigger BEFORE UPDATE ON public.foo FOR EACH ROW WHEN ((old.* IS DISTINCT FROM new.*)) EXECUTE FUNCTION increment_version()", + }, + { + EscapedName: "\"some_partition_trigger\"", + OwningTable: schema.SchemaQualifiedName{EscapedName: "\"foo_1\"", SchemaName: "public"}, + OwningTableUnescapedName: "foo_1", + Function: schema.SchemaQualifiedName{EscapedName: "\"increment_version\"()", SchemaName: "public"}, + GetTriggerDefStmt: "CREATE TRIGGER some_partition_trigger BEFORE UPDATE ON public.foo_1 FOR EACH ROW WHEN ((old.* IS DISTINCT FROM new.*)) EXECUTE FUNCTION increment_version()", + }, + }, + }, + }, + { + name: "Partition test local primary key", + ddl: []string{` + CREATE TABLE foo ( + id INTEGER, + author TEXT + ) PARTITION BY LIST (author); + + CREATE TABLE foo_1 PARTITION OF foo( + PRIMARY KEY (author, id) + ) FOR VALUES IN ('some author 1'); + `}, + expectedHash: "6976f3d0ada49b66", + expectedSchema: schema.Schema{ + Name: "public", + Tables: []schema.Table{ + { + Name: "foo", + Columns: []schema.Column{ + {Name: "id", Type: "integer", IsNullable: true, Size: 4}, + {Name: "author", Type: "text", IsNullable: true, Size: -1, Collation: defaultCollation}, + }, + CheckConstraints: nil, + PartitionKeyDef: "LIST (author)", + }, + { + ParentTableName: "foo", + Name: "foo_1", + Columns: []schema.Column{ + {Name: "id", Type: "integer", Size: 4}, + {Name: "author", Type: "text", Size: -1, Collation: defaultCollation}, + }, + CheckConstraints: nil, + ForValues: "FOR VALUES IN ('some author 1')", + }, + }, + Indexes: []schema.Index{ + { + TableName: "foo_1", + Name: "foo_1_pkey", Columns: []string{"author", "id"}, IsPk: true, IsUnique: true, ConstraintName: "foo_1_pkey", + GetIndexDefStmt: "CREATE UNIQUE INDEX foo_1_pkey ON public.foo_1 USING btree (author, id)", + }, + }, + }, + }, + { + name: "Common Data Types", + ddl: []string{` + CREATE TABLE foo ( + "varchar" VARCHAR(128) NOT NULL DEFAULT '', + "text" TEXT NOT NULL DEFAULT '', + "bool" BOOLEAN NOT NULL DEFAULT False, + "blob" BYTEA NOT NULL DEFAULT '', + "smallint" SMALLINT NOT NULL DEFAULT 0, + "real" REAL NOT NULL DEFAULT 0.0, + "double_precision" DOUBLE PRECISION NOT NULL DEFAULT 0.0, + "integer" INTEGER NOT NULL DEFAULT 0, + "big_integer" BIGINT NOT NULL DEFAULT 0, + "decimal" DECIMAL(65, 10) NOT NULL DEFAULT 0.0 + ); + `}, + expectedHash: "2181b2da75bb74f7", + expectedSchema: schema.Schema{ + Name: "public", + Tables: []schema.Table{ + { + Name: "foo", + Columns: []schema.Column{ + {Name: "varchar", Type: "character varying(128)", Default: "''::character varying", Size: -1, Collation: defaultCollation}, + {Name: "text", Type: "text", Default: "''::text", Size: -1, Collation: defaultCollation}, + {Name: "bool", Type: "boolean", Default: "false", Size: 1}, + {Name: "blob", Type: "bytea", Default: `'\x'::bytea`, Size: -1}, + {Name: "smallint", Type: "smallint", Default: "0", Size: 2}, + {Name: "real", Type: "real", Default: "0.0", Size: 4}, + {Name: "double_precision", Type: "double precision", Default: "0.0", Size: 8}, + {Name: "integer", Type: "integer", Default: "0", Size: 4}, + {Name: "big_integer", Type: "bigint", Default: "0", Size: 8}, + {Name: "decimal", Type: "numeric(65,10)", Default: "0.0", Size: -1}, + }, + CheckConstraints: nil, + }, + }, + }, + }, + { + name: "Multi-Table", + ddl: []string{` + CREATE TABLE foo ( + id INTEGER PRIMARY KEY CHECK (id > 0) NO INHERIT, + content TEXT DEFAULT 'some default' + ); + CREATE INDEX foo_idx ON foo(id, content); + CREATE TABLE bar( + id INTEGER PRIMARY KEY CHECK (id > 0), + content TEXT NOT NULL + ); + CREATE INDEX bar_idx ON bar(content, id); + CREATE TABLE foobar( + id INTEGER PRIMARY KEY, + content BIGINT NOT NULL + ); + ALTER TABLE foobar ADD CONSTRAINT foobar_id_check CHECK (id > 0) NOT VALID; + CREATE UNIQUE INDEX foobar_idx ON foobar(content); + `}, + expectedHash: "6518bbfe220d4f16", + expectedSchema: schema.Schema{ + Name: "public", + Tables: []schema.Table{ + { + Name: "foo", + Columns: []schema.Column{ + {Name: "id", Type: "integer", Size: 4}, + {Name: "content", Type: "text", IsNullable: true, Default: "'some default'::text", Size: -1, Collation: defaultCollation}, + }, + CheckConstraints: []schema.CheckConstraint{ + {Name: "foo_id_check", Expression: "(id > 0)", IsValid: true}, + }, + }, + { + Name: "bar", + Columns: []schema.Column{ + {Name: "id", Type: "integer", Size: 4}, + {Name: "content", Type: "text", Size: -1, Collation: defaultCollation}, + }, + CheckConstraints: []schema.CheckConstraint{ + {Name: "bar_id_check", Expression: "(id > 0)", IsValid: true, IsInheritable: true}, + }, + }, + { + Name: "foobar", + Columns: []schema.Column{ + {Name: "id", Type: "integer", Size: 4}, + {Name: "content", Type: "bigint", Size: 8}, + }, + CheckConstraints: []schema.CheckConstraint{ + {Name: "foobar_id_check", Expression: "(id > 0)", IsInheritable: true}, + }, + }, + }, + Indexes: []schema.Index{ + // foo indexes + { + TableName: "foo", + Name: "foo_pkey", Columns: []string{"id"}, IsPk: true, IsUnique: true, ConstraintName: "foo_pkey", + GetIndexDefStmt: "CREATE UNIQUE INDEX foo_pkey ON public.foo USING btree (id)", + }, + { + TableName: "foo", + Name: "foo_idx", Columns: []string{"id", "content"}, + GetIndexDefStmt: "CREATE INDEX foo_idx ON public.foo USING btree (id, content)", + }, + // bar indexes + { + TableName: "bar", + Name: "bar_pkey", Columns: []string{"id"}, IsPk: true, IsUnique: true, ConstraintName: "bar_pkey", + GetIndexDefStmt: "CREATE UNIQUE INDEX bar_pkey ON public.bar USING btree (id)", + }, + { + TableName: "bar", + Name: "bar_idx", Columns: []string{"content", "id"}, + GetIndexDefStmt: "CREATE INDEX bar_idx ON public.bar USING btree (content, id)", + }, + // foobar indexes + { + TableName: "foobar", + Name: "foobar_pkey", Columns: []string{"id"}, IsPk: true, IsUnique: true, ConstraintName: "foobar_pkey", + GetIndexDefStmt: "CREATE UNIQUE INDEX foobar_pkey ON public.foobar USING btree (id)", + }, + { + TableName: "foobar", + Name: "foobar_idx", Columns: []string{"content"}, IsUnique: true, + GetIndexDefStmt: "CREATE UNIQUE INDEX foobar_idx ON public.foobar USING btree (content)", + }, + }, + }, + }, + { + name: "Multi-Schema", + ddl: []string{` + CREATE TABLE foo ( + id INTEGER PRIMARY KEY CHECK (id > 0), + version INTEGER NOT NULL DEFAULT 0, + content TEXT + ); + + CREATE FUNCTION dup(in int, out f1 int, out f2 text) + AS $$ SELECT $1, CAST($1 AS text) || ' is text' $$ + LANGUAGE SQL; + + + CREATE FUNCTION increment_version() RETURNS TRIGGER AS $$ + BEGIN + NEW.version = OLD.version + 1; + RETURN NEW; + END; + $$ language 'plpgsql'; + + CREATE TRIGGER some_trigger + BEFORE UPDATE ON foo + FOR EACH ROW + WHEN (OLD.* IS DISTINCT FROM NEW.*) + EXECUTE PROCEDURE increment_version(); + + CREATE SCHEMA test; + + CREATE TABLE test.foo( + test_schema_id INTEGER PRIMARY KEY, + test_schema_version INTEGER NOT NULL DEFAULT 0, + test_schema_content TEXT CHECK (LENGTH(test_schema_content) > 0) + ); + + CREATE FUNCTION test.dup(in int, out f1 int, out f2 text) + AS $$ SELECT $1, ' is int' $$ + LANGUAGE SQL; + + CREATE FUNCTION test.increment_version() RETURNS TRIGGER AS $$ + BEGIN + NEW.version = OLD.version + 1; + RETURN NEW; + END; + $$ language 'plpgsql'; + + CREATE TRIGGER some_trigger + BEFORE UPDATE ON test.foo + FOR EACH ROW + WHEN (OLD.* IS DISTINCT FROM NEW.*) + EXECUTE PROCEDURE increment_version(); + + CREATE COLLATION test."some collation" (locale = 'en_US'); + + CREATE TABLE bar ( + id INTEGER CHECK (id > 0), + author TEXT COLLATE test."some collation", + PRIMARY KEY (author, id) + ) PARTITION BY LIST (author); + CREATE INDEX some_partitioned_idx ON bar(author, id); + CREATE TABLE bar_1 PARTITION OF bar FOR VALUES IN ('some author 1'); + + CREATE TABLE test.bar ( + test_id INTEGER CHECK (test_id > 0), + test_author TEXT, + PRIMARY KEY (test_author, test_id) + ) PARTITION BY LIST (test_author); + CREATE INDEX some_partitioned_idx ON test.bar(test_author, test_id); + CREATE TABLE test.bar_1 PARTITION OF test.bar FOR VALUES IN ('some author 1'); + + -- create a trigger on the original schema using a function from the other schema + CREATE TRIGGER some_trigger_using_other_schema_function + BEFORE UPDATE ON foo + FOR EACH ROW + WHEN (OLD.* IS DISTINCT FROM NEW.*) + EXECUTE PROCEDURE test.increment_version(); + `}, + expectedHash: "a6a845ad846dc362", + expectedSchema: schema.Schema{ + Name: "public", + Tables: []schema.Table{ + { + Name: "foo", + Columns: []schema.Column{ + {Name: "id", Type: "integer", Size: 4}, + {Name: "version", Type: "integer", Default: "0", Size: 4}, + {Name: "content", Type: "text", IsNullable: true, Size: -1, Collation: defaultCollation}, + }, + CheckConstraints: []schema.CheckConstraint{ + {Name: "foo_id_check", Expression: "(id > 0)", IsValid: true, IsInheritable: true}, + }, + }, + { + Name: "bar", + Columns: []schema.Column{ + {Name: "id", Type: "integer", Default: "", Size: 4}, + {Name: "author", Type: "text", Default: "", Size: -1, Collation: schema.SchemaQualifiedName{SchemaName: "test", EscapedName: `"some collation"`}}, + }, + CheckConstraints: []schema.CheckConstraint{ + {Name: "bar_id_check", Expression: "(id > 0)", IsValid: true, IsInheritable: true}, + }, + PartitionKeyDef: "LIST (author)", + }, + { + ParentTableName: "bar", + Name: "bar_1", + Columns: []schema.Column{ + {Name: "id", Type: "integer", Default: "", Size: 4}, + {Name: "author", Type: "text", Default: "", Size: -1, Collation: schema.SchemaQualifiedName{SchemaName: "test", EscapedName: `"some collation"`}}, + }, + CheckConstraints: nil, + ForValues: "FOR VALUES IN ('some author 1')", + }, + }, + Indexes: []schema.Index{ + // foo indexes + { + TableName: "foo", + Name: "foo_pkey", Columns: []string{"id"}, IsPk: true, IsUnique: true, ConstraintName: "foo_pkey", + GetIndexDefStmt: "CREATE UNIQUE INDEX foo_pkey ON public.foo USING btree (id)", + }, + // bar indexes + { + TableName: "bar", + Name: "bar_pkey", Columns: []string{"author", "id"}, IsPk: true, IsUnique: true, ConstraintName: "bar_pkey", ParentIdxName: "", + GetIndexDefStmt: "CREATE UNIQUE INDEX bar_pkey ON ONLY public.bar USING btree (author, id)", + }, + { + TableName: "bar", + Name: "some_partitioned_idx", Columns: []string{"author", "id"}, IsPk: false, IsUnique: false, ConstraintName: "", ParentIdxName: "", + GetIndexDefStmt: "CREATE INDEX some_partitioned_idx ON ONLY public.bar USING btree (author, id)", + }, + // bar_1 indexes + { + TableName: "bar_1", + Name: "bar_1_author_id_idx", Columns: []string{"author", "id"}, IsPk: false, IsUnique: false, ConstraintName: "", ParentIdxName: "some_partitioned_idx", + GetIndexDefStmt: "CREATE INDEX bar_1_author_id_idx ON public.bar_1 USING btree (author, id)", + }, + { + TableName: "bar_1", + Name: "bar_1_pkey", Columns: []string{"author", "id"}, IsPk: true, IsUnique: true, ConstraintName: "bar_1_pkey", ParentIdxName: "bar_pkey", + GetIndexDefStmt: "CREATE UNIQUE INDEX bar_1_pkey ON public.bar_1 USING btree (author, id)", + }, + }, + Functions: []schema.Function{ + { + SchemaQualifiedName: schema.SchemaQualifiedName{EscapedName: "\"dup\"(integer, OUT f1 integer, OUT f2 text)", SchemaName: "public"}, + FunctionDef: "CREATE OR REPLACE FUNCTION public.dup(integer, OUT f1 integer, OUT f2 text)\n RETURNS record\n LANGUAGE sql\nAS $function$ SELECT $1, CAST($1 AS text) || ' is text' $function$\n", + Language: "sql", + }, + { + SchemaQualifiedName: schema.SchemaQualifiedName{EscapedName: "\"increment_version\"()", SchemaName: "public"}, + FunctionDef: "CREATE OR REPLACE FUNCTION public.increment_version()\n RETURNS trigger\n LANGUAGE plpgsql\nAS $function$\n\t\t\t\tBEGIN\n\t\t\t\t\tNEW.version = OLD.version + 1;\n\t\t\t\t\tRETURN NEW;\n\t\t\t\tEND;\n\t\t\t$function$\n", + Language: "plpgsql", + }, + }, + Triggers: []schema.Trigger{ + { + EscapedName: "\"some_trigger\"", + OwningTable: schema.SchemaQualifiedName{EscapedName: "\"foo\"", SchemaName: "public"}, + OwningTableUnescapedName: "foo", + Function: schema.SchemaQualifiedName{EscapedName: "\"increment_version\"()", SchemaName: "public"}, + GetTriggerDefStmt: "CREATE TRIGGER some_trigger BEFORE UPDATE ON public.foo FOR EACH ROW WHEN ((old.* IS DISTINCT FROM new.*)) EXECUTE FUNCTION increment_version()", + }, + }, + }, + }, + { + name: "Empty Schema", + ddl: nil, + expectedHash: "660be155e4c39f8b", + expectedSchema: schema.Schema{ + Name: "public", + Tables: nil, + }, + }, + { + name: "No Indexes or constraints", + ddl: []string{` + CREATE TABLE foo ( + value TEXT + ); + `}, + expectedHash: "9db57cf969f0a509", + expectedSchema: schema.Schema{ + Name: "public", + Tables: []schema.Table{ + { + Name: "foo", + Columns: []schema.Column{ + {Name: "value", Type: "text", IsNullable: true, Size: -1, Collation: defaultCollation}, + }, + CheckConstraints: nil, + }, + }, + }, + }, + } +) + +func TestSchemaTestCases(t *testing.T) { + engine, err := pgengine.StartEngine() + require.NoError(t, err) + defer engine.Close() + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + db, err := engine.CreateDatabase() + require.NoError(t, err) + conn, err := sql.Open("pgx", db.GetDSN()) + require.NoError(t, err) + + for _, stmt := range testCase.ddl { + _, err := conn.Exec(stmt) + require.NoError(t, err) + } + + fetchedSchema, err := schema.GetPublicSchema(context.TODO(), conn) + if testCase.expectedErrIs != nil { + require.ErrorIs(t, err, testCase.expectedErrIs) + return + } else { + require.NoError(t, err) + } + + expectedNormalized := testCase.expectedSchema.Normalize() + fetchedNormalized := fetchedSchema.Normalize() + assert.Equal(t, expectedNormalized, fetchedNormalized, "expected=\n%# v \n fetched=%# v\n", pretty.Formatter(expectedNormalized), pretty.Formatter(fetchedNormalized)) + + fetchedSchemaHash, err := fetchedSchema.Hash() + require.NoError(t, err) + expectedSchemaHash, err := testCase.expectedSchema.Hash() + require.NoError(t, err) + assert.Equal(t, testCase.expectedHash, fetchedSchemaHash) + // same schemas should have the same hashes + assert.Equal(t, expectedSchemaHash, fetchedSchemaHash, "hash of expected schema should match fetched hash") + + require.NoError(t, conn.Close()) + require.NoError(t, db.DropDB()) + }) + } +} + +func TestIdxDefStmtToCreateIdxConcurrently(t *testing.T) { + for _, tc := range []struct { + name string + defStmt string + out string + expectErr bool + }{ + { + name: "simple index", + defStmt: `CREATE INDEX foobar ON public.foobar USING btree (foo)`, + out: `CREATE INDEX CONCURRENTLY foobar ON public.foobar USING btree (foo)`, + }, + { + name: "unique index", + defStmt: `CREATE UNIQUE INDEX foobar ON public.foobar USING btree (foo)`, + out: `CREATE UNIQUE INDEX CONCURRENTLY foobar ON public.foobar USING btree (foo)`, + }, + { + name: "malicious name index", + defStmt: `CREATE UNIQUE INDEX "CREATE INDEX ON" ON public.foobar USING btree (foo)`, + out: `CREATE UNIQUE INDEX CONCURRENTLY "CREATE INDEX ON" ON public.foobar USING btree (foo)`, + }, + { + name: "case sensitive", + defStmt: `CREATE uNIQUE INDEX foobar ON public.foobar USING btree (foo)`, + expectErr: true, + }, + { + name: "errors with random start character", + defStmt: `ALTER TABLE CREATE UNIQUE INDEX foobar ON public.foobar USING btree (foo)`, + expectErr: true, + }, + } { + t.Run(tc.name, func(t *testing.T) { + out, err := schema.GetIndexDefStatement(tc.defStmt).ToCreateIndexConcurrently() + if tc.expectErr { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, tc.out, out) + } + }) + } +} + +func TestTriggerDefStmtToCreateOrReplace(t *testing.T) { + for _, tc := range []struct { + name string + defStmt string + out string + expectErr bool + }{ + { + name: "simple trigger", + defStmt: "CREATE TRIGGER some_trigger BEFORE UPDATE ON public.foo FOR EACH ROW WHEN ((old.* IS DISTINCT FROM new.*)) EXECUTE FUNCTION increment_version()", + out: "CREATE OR REPLACE TRIGGER some_trigger BEFORE UPDATE ON public.foo FOR EACH ROW WHEN ((old.* IS DISTINCT FROM new.*)) EXECUTE FUNCTION increment_version()", + }, + { + name: "malicious name trigger", + defStmt: `CREATE TRIGGER "CREATE TRIGGER" BEFORE UPDATE ON public.foo FOR EACH ROW WHEN ((old.* IS DISTINCT FROM new.*)) EXECUTE FUNCTION increment_version()`, + out: `CREATE OR REPLACE TRIGGER "CREATE TRIGGER" BEFORE UPDATE ON public.foo FOR EACH ROW WHEN ((old.* IS DISTINCT FROM new.*)) EXECUTE FUNCTION increment_version()`, + }, + { + name: "case sensitive", + defStmt: "cREATE TRIGGER some_trigger BEFORE UPDATE ON public.foo FOR EACH ROW WHEN ((old.* IS DISTINCT FROM new.*)) EXECUTE FUNCTION increment_version()", + expectErr: true, + }, + { + name: "errors with random start character", + defStmt: "ALTER TRIGGER CREATE TRIGGER some_trigger BEFORE UPDATE ON public.foo FOR EACH ROW WHEN ((old.* IS DISTINCT FROM new.*)) EXECUTE FUNCTION increment_version()", + expectErr: true, + }, + } { + t.Run(tc.name, func(t *testing.T) { + out, err := schema.GetTriggerDefStatement(tc.defStmt).ToCreateOrReplace() + if tc.expectErr { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, tc.out, out) + } + }) + } +} diff --git a/pkg/diff/diff.go b/pkg/diff/diff.go new file mode 100644 index 0000000..be9cfb5 --- /dev/null +++ b/pkg/diff/diff.go @@ -0,0 +1,323 @@ +package diff + +import ( + "fmt" + "sort" + + "github.com/stripe/pg-schema-diff/internal/graph" + "github.com/stripe/pg-schema-diff/internal/schema" +) + +var ErrNotImplemented = fmt.Errorf("not implemented") +var errDuplicateIdentifier = fmt.Errorf("duplicate identifier") + +type diffType string + +const ( + diffTypeDelete diffType = "DELETE" + diffTypeAddAlter diffType = "ADDALTER" +) + +type ( + diff[S schema.Object] interface { + GetOld() S + GetNew() S + } + + // sqlGenerator is used to generate SQL that resolves diffs between lists + sqlGenerator[S schema.Object, Diff diff[S]] interface { + Add(S) ([]Statement, error) + Delete(S) ([]Statement, error) + // Alter generates the statements required to resolve the schema object to its new state using the + // provided diff. Alter, e.g., with a table, might produce add/delete statements + Alter(Diff) ([]Statement, error) + } + + // dependency indicates an edge between the SQL to resolve a diff for a source schema object and the SQL to resolve + // the diff of a target schema object + // + // Most SchemaObjects will have two nodes in the SQL graph: a node for delete SQL and a node for add/alter SQL. + // These nodes will almost always be present in the sqlGraph even if the schema object is not being deleted (or added/altered). + // If a node is present for a schema object where the "diffType" is NOT occurring, it will just be a no-op (no SQl statements) + dependency struct { + sourceObjId string + sourceType diffType + + targetObjId string + targetType diffType + } +) + +type dependencyBuilder struct { + valObjId string + valType diffType +} + +func mustRun(schemaObjId string, schemaDiffType diffType) dependencyBuilder { + return dependencyBuilder{ + valObjId: schemaObjId, + valType: schemaDiffType, + } +} + +func (d dependencyBuilder) before(valObjId string, valType diffType) dependency { + return dependency{ + sourceType: d.valType, + sourceObjId: d.valObjId, + + targetType: valType, + targetObjId: valObjId, + } +} + +func (d dependencyBuilder) after(valObjId string, valType diffType) dependency { + return dependency{ + sourceObjId: valObjId, + sourceType: valType, + + targetObjId: d.valObjId, + targetType: d.valType, + } +} + +// sqlVertexGenerator is used to generate SQL statements for schema objects that have dependency webs +// with other schema objects. The schema object represents a vertex in the graph. +type sqlVertexGenerator[S schema.Object, Diff diff[S]] interface { + sqlGenerator[S, Diff] + // GetSQLVertexId gets the canonical vertex id to represent the schema object + GetSQLVertexId(S) string + + // GetAddAlterDependencies gets the dependencies of the SQL generated to resolve the AddAlter diff for the + // schema objects. Dependencies can be formed on any other nodes in the SQL graph, even if the node has + // no statements. If the diff is just an add, then old will be the zero value + // + // These dependencies can also be built in reverse: the SQL returned by the sqlVertexGenerator to resolve the + // diff for the object must always be run before the SQL required to resolve another SQL vertex diff + GetAddAlterDependencies(new S, old S) []dependency + + // GetDeleteDependencies is the same as above but for deletes. + // Invariant to maintain: + // - If an object X depends on the delete for an object Y (generated by the sqlVertexGenerator), immediately after the + // the (Y, diffTypeDelete) sqlVertex's SQL is run, Y must no longer be present in the schema; either the + // (Y, diffTypeDelete) statements deleted Y or something that vertex depended on deleted Y. In other words, if a + // delete is cascaded by another delete (e.g., index dropped by table drop) and the index SQL is empty, + // the index delete vertex must still have dependency from itself to the object from which the delete cascades down from + GetDeleteDependencies(S) []dependency +} + +type ( + // listDiff represents the differences between two lists. + listDiff[S schema.Object, Diff diff[S]] struct { + adds []S + deletes []S + // alters contains the diffs of any objects that persisted between two schemas + alters []Diff + } + + sqlGroupedByEffect[S schema.Object, Diff diff[S]] struct { + Adds []Statement + Deletes []Statement + // Alters might contain adds and deletes. For example, a set of alters for a table might add indexes. + Alters []Statement + } +) + +func (ld listDiff[S, D]) isEmpty() bool { + return len(ld.adds) == 0 || len(ld.alters) == 0 || len(ld.deletes) == 0 +} + +func (ld listDiff[S, D]) resolveToSQLGroupedByEffect(sqlGenerator sqlGenerator[S, D]) (sqlGroupedByEffect[S, D], error) { + var adds, deletes, alters []Statement + + for _, a := range ld.adds { + statements, err := sqlGenerator.Add(a) + if err != nil { + return sqlGroupedByEffect[S, D]{}, fmt.Errorf("generating SQL for add %s: %w", a.GetName(), err) + } + adds = append(adds, statements...) + } + for _, d := range ld.deletes { + statements, err := sqlGenerator.Delete(d) + if err != nil { + return sqlGroupedByEffect[S, D]{}, fmt.Errorf("generating SQL for delete %s: %w", d.GetName(), err) + } + deletes = append(deletes, statements...) + } + for _, a := range ld.alters { + statements, err := sqlGenerator.Alter(a) + if err != nil { + return sqlGroupedByEffect[S, D]{}, fmt.Errorf("generating SQL for diff %+v: %w", a, err) + } + alters = append(alters, statements...) + } + + return sqlGroupedByEffect[S, D]{ + Adds: adds, + Deletes: deletes, + Alters: alters, + }, nil +} + +func (ld listDiff[S, D]) resolveToSQLGraph(generator sqlVertexGenerator[S, D]) (*sqlGraph, error) { + graph := graph.NewGraph[sqlVertex]() + + for _, a := range ld.adds { + statements, err := generator.Add(a) + if err != nil { + return nil, fmt.Errorf("generating SQL for add %s: %w", a.GetName(), err) + } + + if err := addSQLVertexToGraph(graph, sqlVertex{ + ObjId: generator.GetSQLVertexId(a), + Statements: statements, + DiffType: diffTypeAddAlter, + }, generator.GetAddAlterDependencies(a, *new(S))); err != nil { + return nil, fmt.Errorf("adding SQL Vertex for add %s: %w", a.GetName(), err) + } + } + + for _, a := range ld.alters { + statements, err := generator.Alter(a) + if err != nil { + return nil, fmt.Errorf("generating SQL for diff %+v: %w", a, err) + } + + vertexId := generator.GetSQLVertexId(a.GetOld()) + vertexIdAfterAlter := generator.GetSQLVertexId(a.GetNew()) + if vertexIdAfterAlter != vertexId { + return nil, fmt.Errorf("an alter lead to a node with a different id: old=%s, new=%s", vertexId, vertexIdAfterAlter) + } + + if err := addSQLVertexToGraph(graph, sqlVertex{ + ObjId: vertexId, + Statements: statements, + DiffType: diffTypeAddAlter, + }, generator.GetAddAlterDependencies(a.GetNew(), a.GetNew())); err != nil { + return nil, fmt.Errorf("adding SQL Vertex for alter %s: %w", a.GetOld().GetName(), err) + } + } + + for _, d := range ld.deletes { + statements, err := generator.Delete(d) + if err != nil { + return nil, fmt.Errorf("generating SQL for delete %s: %w", d.GetName(), err) + } + + if err := addSQLVertexToGraph(graph, sqlVertex{ + ObjId: generator.GetSQLVertexId(d), + Statements: statements, + DiffType: diffTypeDelete, + }, generator.GetDeleteDependencies(d)); err != nil { + return nil, fmt.Errorf("adding SQL Vertex for delete %s: %w", d.GetName(), err) + } + } + + return (*sqlGraph)(graph), nil +} + +func addSQLVertexToGraph(graph *graph.Graph[sqlVertex], vertex sqlVertex, dependencies []dependency) error { + // It's possible the node already exists. merge it if it does + if graph.HasVertexWithId(vertex.GetId()) { + vertex = mergeSQLVertices(graph.GetVertex(vertex.GetId()), vertex) + } + graph.AddVertex(vertex) + for _, dep := range dependencies { + if err := addDependency(graph, dep); err != nil { + return fmt.Errorf("adding dependencies for %s: %w", vertex.GetId(), err) + } + } + return nil +} + +func addDependency(graph *graph.Graph[sqlVertex], dep dependency) error { + sourceVertex := sqlVertex{ + ObjId: dep.sourceObjId, + DiffType: dep.sourceType, + Statements: nil, + } + targetVertex := sqlVertex{ + ObjId: dep.targetObjId, + DiffType: dep.targetType, + Statements: nil, + } + + // To maintain the correctness of the graph, we will add a dummy vertex for the missing dependencies + addVertexIfNotExists(graph, sourceVertex) + addVertexIfNotExists(graph, targetVertex) + + if err := graph.AddEdge(sourceVertex.GetId(), targetVertex.GetId()); err != nil { + return fmt.Errorf("adding edge from %s to %s: %w", sourceVertex.GetId(), targetVertex.GetId(), err) + } + + return nil +} + +func addVertexIfNotExists(graph *graph.Graph[sqlVertex], vertex sqlVertex) { + if !graph.HasVertexWithId(vertex.GetId()) { + graph.AddVertex(vertex) + } +} + +type schemaObjectEntry[S schema.Object] struct { + index int // index is the index the schema object in the list + obj S +} + +// diffLists diffs two lists of schema objects using name. +// If an object is present in both lists, it will use buildDiff function to build the diffs between the two objects. If +// build diff returns as requiresRecreation, then the old schema object will be deleted and the new one will be added +// +// The List will outputted in a deterministic order by schema object name, which is important for tests +func diffLists[S schema.Object, Diff diff[S]]( + oldSchemaObjs, newSchemaObjs []S, + buildDiff func(old, new S, oldIndex, newIndex int) (diff Diff, requiresRecreation bool, error error), +) (listDiff[S, Diff], error) { + nameToOld := make(map[string]schemaObjectEntry[S]) + for oldIndex, oldSchemaObject := range oldSchemaObjs { + if _, nameAlreadyTaken := nameToOld[oldSchemaObject.GetName()]; nameAlreadyTaken { + return listDiff[S, Diff]{}, fmt.Errorf("multiple objects have identifier %s: %w", oldSchemaObject.GetName(), errDuplicateIdentifier) + } + // store the old schema object and its index. if an alteration, the index might be used in the diff, e.g., for columns + nameToOld[oldSchemaObject.GetName()] = schemaObjectEntry[S]{ + obj: oldSchemaObject, + index: oldIndex, + } + } + + var adds []S + var alters []Diff + var deletes []S + for newIndex, newSchemaObj := range newSchemaObjs { + if oldSchemaObjAndIndex, hasOldSchemaObj := nameToOld[newSchemaObj.GetName()]; !hasOldSchemaObj { + adds = append(adds, newSchemaObj) + } else { + delete(nameToOld, newSchemaObj.GetName()) + + diff, requiresRecreation, err := buildDiff(oldSchemaObjAndIndex.obj, newSchemaObj, oldSchemaObjAndIndex.index, newIndex) + if err != nil { + return listDiff[S, Diff]{}, fmt.Errorf("diffing for %s: %w", newSchemaObj.GetName(), err) + } + if requiresRecreation { + deletes = append(deletes, oldSchemaObjAndIndex.obj) + adds = append(adds, newSchemaObj) + } else { + alters = append(alters, diff) + } + } + } + + // Remaining schema objects in nameToOld have been deleted + for _, d := range nameToOld { + deletes = append(deletes, d.obj) + } + // Iterating through a map is non-deterministic in go, so we'll sort the deletes by schema object name + sort.Slice(deletes, func(i, j int) bool { + return deletes[i].GetName() < deletes[j].GetName() + }) + + return listDiff[S, Diff]{ + adds: adds, + deletes: deletes, + alters: alters, + }, nil +} diff --git a/pkg/diff/plan.go b/pkg/diff/plan.go new file mode 100644 index 0000000..589903d --- /dev/null +++ b/pkg/diff/plan.go @@ -0,0 +1,69 @@ +package diff + +import ( + "fmt" + "regexp" + "time" +) + +type MigrationHazardType = string + +const ( + MigrationHazardTypeAcquiresAccessExclusiveLock MigrationHazardType = "ACQUIRES_ACCESS_EXCLUSIVE_LOCK" + MigrationHazardTypeAcquiresShareLock MigrationHazardType = "ACQUIRES_SHARE_LOCK" + MigrationHazardTypeDeletesData MigrationHazardType = "DELETES_DATA" + MigrationHazardTypeHasUntrackableDependencies MigrationHazardType = "HAS_UNTRACKABLE_DEPENDENCIES" + MigrationHazardTypeIndexBuild MigrationHazardType = "INDEX_BUILD" + MigrationHazardTypeIndexDropped MigrationHazardType = "INDEX_DROPPED" + MigrationHazardTypeImpactsDatabasePerformance MigrationHazardType = "IMPACTS_DATABASE_PERFORMANCE" + MigrationHazardTypeIsUserGenerated MigrationHazardType = "IS_USER_GENERATED" +) + +type MigrationHazard struct { + Type MigrationHazardType + Message string +} + +func (p MigrationHazard) String() string { + return fmt.Sprintf("%s: %s", p.Type, p.Message) +} + +type Statement struct { + DDL string + Timeout time.Duration + Hazards []MigrationHazard +} + +func (s Statement) ToSQL() string { + return s.DDL + ";" +} + +type Plan struct { + Statements []Statement + CurrentSchemaHash string +} + +func (p Plan) ApplyStatementTimeoutModifier(regex *regexp.Regexp, timeout time.Duration) Plan { + var modifiedStmts []Statement + for _, stmt := range p.Statements { + if regex.MatchString(stmt.DDL) { + stmt.Timeout = timeout + } + modifiedStmts = append(modifiedStmts, stmt) + } + p.Statements = modifiedStmts + return p +} + +func (p Plan) InsertStatement(index int, statement Statement) (Plan, error) { + if index < 0 || index > len(p.Statements) { + return Plan{}, fmt.Errorf("index must be >= 0 and <= %d", len(p.Statements)) + } + if index == len(p.Statements) { + p.Statements = append(p.Statements, statement) + return p, nil + } + p.Statements = append(p.Statements[:index+1], p.Statements[index:]...) + p.Statements[index] = statement + return p, nil +} diff --git a/pkg/diff/plan_generator.go b/pkg/diff/plan_generator.go new file mode 100644 index 0000000..7025d95 --- /dev/null +++ b/pkg/diff/plan_generator.go @@ -0,0 +1,244 @@ +package diff + +import ( + "context" + "database/sql" + "fmt" + "strings" + + _ "github.com/jackc/pgx/v4/stdlib" + "github.com/kr/pretty" + "github.com/stripe/pg-schema-diff/internal/schema" + + "github.com/stripe/pg-schema-diff/internal/queries" + "github.com/stripe/pg-schema-diff/pkg/log" + "github.com/stripe/pg-schema-diff/pkg/tempdb" +) + +type ( + planOptions struct { + dataPackNewTables bool + ignoreChangesToColOrder bool + logger log.Logger + validatePlan bool + } + + PlanOpt func(opts *planOptions) +) + +// WithDataPackNewTables configures the plan generation such that it packs the columns in the new tables to minimize +// padding. It will help minimize the storage used by the tables +func WithDataPackNewTables() PlanOpt { + return func(opts *planOptions) { + opts.dataPackNewTables = true + } +} + +// WithIgnoreChangesToColOrder configures the plan generation to ignore any changes to the ordering of columns in +// existing tables. You will most likely want this enabled +func WithIgnoreChangesToColOrder() PlanOpt { + return func(opts *planOptions) { + opts.ignoreChangesToColOrder = true + } +} + +func WithDoNotValidatePlan() PlanOpt { + return func(opts *planOptions) { + opts.validatePlan = false + } +} + +func WithLogger(logger log.Logger) PlanOpt { + return func(opts *planOptions) { + opts.logger = logger + } +} + +func GeneratePlan(ctx context.Context, conn *sql.Conn, tempDbFactory tempdb.Factory, newDDL []string, opts ...PlanOpt) (Plan, error) { + planOptions := &planOptions{ + validatePlan: true, + logger: log.SimpleLogger(), + } + for _, opt := range opts { + opt(planOptions) + } + + currentSchema, err := schema.GetPublicSchema(ctx, conn) + if err != nil { + return Plan{}, fmt.Errorf("getting current schema: %w", err) + } + newSchema, err := deriveSchemaFromDDLOnTempDb(ctx, planOptions.logger, tempDbFactory, newDDL) + if err != nil { + return Plan{}, fmt.Errorf("getting new schema: %w", err) + } + + statements, err := generateMigrationStatements(currentSchema, newSchema, planOptions) + if err != nil { + return Plan{}, fmt.Errorf("generating plan statements: %w", err) + } + + hash, err := currentSchema.Hash() + if err != nil { + return Plan{}, fmt.Errorf("generating current schema hash: %w", err) + } + + plan := Plan{ + Statements: statements, + CurrentSchemaHash: hash, + } + + if planOptions.validatePlan { + if err := assertValidPlan(ctx, tempDbFactory, currentSchema, newSchema, plan, planOptions); err != nil { + return Plan{}, fmt.Errorf("validating migration plan: %w \n%# v", err, pretty.Formatter(plan)) + } + } + + return plan, nil +} + +func deriveSchemaFromDDLOnTempDb(ctx context.Context, logger log.Logger, tempDbFactory tempdb.Factory, ddl []string) (schema.Schema, error) { + tempDb, dropTempDb, err := tempDbFactory.Create(ctx) + if err != nil { + return schema.Schema{}, fmt.Errorf("creating temp database: %w", err) + } + defer func(drop tempdb.Dropper) { + if err := drop(ctx); err != nil { + logger.Errorf("an error occurred while dropping the temp database: %s", err) + } + }(dropTempDb) + + for _, stmt := range ddl { + if _, err := tempDb.ExecContext(ctx, stmt); err != nil { + return schema.Schema{}, fmt.Errorf("running DDL: %w", err) + } + } + + return schema.GetPublicSchema(ctx, tempDb) +} + +func generateMigrationStatements(oldSchema, newSchema schema.Schema, planOptions *planOptions) ([]Statement, error) { + diff, _, err := buildSchemaDiff(oldSchema, newSchema) + if err != nil { + return nil, err + } + + if planOptions.dataPackNewTables { + // Instead of enabling ignoreChangesToColOrder by default, force the user to enable ignoreChangesToColOrder. + // This ensures the user knows what's going on behind-the-scenes + if !planOptions.ignoreChangesToColOrder { + return nil, fmt.Errorf("cannot data pack new tables without also ignoring changes to column order") + } + diff = dataPackNewTables(diff) + } + if planOptions.ignoreChangesToColOrder { + diff = removeChangesToColumnOrdering(diff) + } + + statements, err := diff.resolveToSQL() + if err != nil { + return nil, fmt.Errorf("generating migration statements: %w", err) + } + return statements, nil +} + +func assertValidPlan(ctx context.Context, + tempDbFactory tempdb.Factory, + currentSchema, newSchema schema.Schema, + plan Plan, + planOptions *planOptions, +) error { + tempDb, dropTempDb, err := tempDbFactory.Create(ctx) + if err != nil { + return err + } + defer func(drop tempdb.Dropper) { + if err := drop(ctx); err != nil { + planOptions.logger.Errorf("an error occurred while dropping the temp database: %s", err) + } + }(dropTempDb) + + tempDbConn, err := tempDb.Conn(ctx) + if err != nil { + return fmt.Errorf("opening database connection: %w", err) + } + defer tempDbConn.Close() + + if err := setSchemaForEmptyDatabase(ctx, tempDbConn, currentSchema); err != nil { + return fmt.Errorf("inserting schema in temporary database: %w", err) + } + + if err := executeStatements(ctx, tempDbConn, plan.Statements); err != nil { + return fmt.Errorf("running migration plan: %w", err) + } + + migratedSchema, err := schema.GetPublicSchema(ctx, tempDbConn) + if err != nil { + return fmt.Errorf("fetching schema from migrated database: %w", err) + } + + return assertMigratedSchemaMatchesTarget(migratedSchema, newSchema, planOptions) +} + +func setSchemaForEmptyDatabase(ctx context.Context, conn *sql.Conn, dbSchema schema.Schema) error { + // We can't create invalid indexes. We'll mark them valid in the schema, which should be functionally + // equivalent for the sake of DDL and other statements. + // + // Make a new array, so we don't mutate the underlying array of the original schema. Ideally, we have a clone function + // in the future + var validIndexes []schema.Index + for _, idx := range dbSchema.Indexes { + idx.IsInvalid = false + validIndexes = append(validIndexes, idx) + } + dbSchema.Indexes = validIndexes + + if statements, err := generateMigrationStatements(schema.Schema{ + Name: "public", + Tables: nil, + Indexes: nil, + Functions: nil, + Triggers: nil, + }, dbSchema, &planOptions{}); err != nil { + return fmt.Errorf("building schema diff: %w", err) + } else { + return executeStatements(ctx, conn, statements) + } +} + +func assertMigratedSchemaMatchesTarget(migratedSchema, targetSchema schema.Schema, planOptions *planOptions) error { + toTargetSchemaStmts, err := generateMigrationStatements(migratedSchema, targetSchema, planOptions) + if err != nil { + return fmt.Errorf("building schema diff between migrated database and new schema: %w", err) + } + + if len(toTargetSchemaStmts) > 0 { + var stmtsStrs []string + for _, stmt := range toTargetSchemaStmts { + stmtsStrs = append(stmtsStrs, stmt.DDL) + } + return fmt.Errorf("diff detected:\n%s", strings.Join(stmtsStrs, "\n")) + } + + return nil +} + +// executeStatements executes the statements using the sql connection. It will modify the session-level +// statement timeout of the underlying connection. +func executeStatements(ctx context.Context, conn queries.DBTX, statements []Statement) error { + // Due to the way *sql.Db works, when a statement_timeout is set for the session, it will NOT reset + // by default when it's returned to the pool. + // + // We can't set the timeout at the TRANSACTION-level (for each transaction) because `ADD INDEX CONCURRENTLY` + // must be executed within its own transaction block. Postgres will error if you try to set a TRANSACTION-level + // timeout for it. SESSION-level statement_timeouts are respected by `ADD INDEX CONCURRENTLY` + for _, stmt := range statements { + if _, err := conn.ExecContext(ctx, fmt.Sprintf("SET SESSION statement_timeout = %d", stmt.Timeout.Milliseconds())); err != nil { + return fmt.Errorf("setting statement timeout: %w", err) + } + if _, err := conn.ExecContext(ctx, stmt.ToSQL()); err != nil { + // could the migration statement contain sensitive information? + return fmt.Errorf("executing migration statement: %s: %w", stmt, err) + } + } + return nil +} diff --git a/pkg/diff/plan_generator_test.go b/pkg/diff/plan_generator_test.go new file mode 100644 index 0000000..c29b332 --- /dev/null +++ b/pkg/diff/plan_generator_test.go @@ -0,0 +1,125 @@ +package diff_test + +import ( + "context" + "database/sql" + "io" + "testing" + + _ "github.com/jackc/pgx/v4/stdlib" + "github.com/stretchr/testify/suite" + "github.com/stripe/pg-schema-diff/internal/pgengine" + "github.com/stripe/pg-schema-diff/pkg/diff" + + "github.com/stripe/pg-schema-diff/pkg/tempdb" +) + +type simpleMigratorTestSuite struct { + suite.Suite + + pgEngine *pgengine.Engine + db *pgengine.DB +} + +func (suite *simpleMigratorTestSuite) mustGetTestDBPool() *sql.DB { + pool, err := sql.Open("pgx", suite.db.GetDSN()) + suite.NoError(err) + return pool +} + +func (suite *simpleMigratorTestSuite) mustGetTestDBConn() (conn *sql.Conn, poolCloser io.Closer) { + pool := suite.mustGetTestDBPool() + conn, err := pool.Conn(context.Background()) + suite.Require().NoError(err) + return conn, pool +} + +func (suite *simpleMigratorTestSuite) mustBuildTempDbFactory(ctx context.Context) tempdb.Factory { + tempDbFactory, err := tempdb.NewOnInstanceFactory(ctx, func(ctx context.Context, dbName string) (*sql.DB, error) { + return sql.Open("pgx", suite.pgEngine.GetPostgresDatabaseConnOpts().With("dbname", dbName).ToDSN()) + }) + suite.Require().NoError(err) + return tempDbFactory +} + +func (suite *simpleMigratorTestSuite) mustApplyDDLToTestDb(ddl []string) { + conn := suite.mustGetTestDBPool() + defer conn.Close() + + for _, stmt := range ddl { + _, err := conn.Exec(stmt) + suite.NoError(err) + } +} + +func (suite *simpleMigratorTestSuite) SetupSuite() { + engine, err := pgengine.StartEngine() + suite.Require().NoError(err) + suite.pgEngine = engine +} + +func (suite *simpleMigratorTestSuite) TearDownSuite() { + suite.pgEngine.Close() +} + +func (suite *simpleMigratorTestSuite) SetupTest() { + db, err := suite.pgEngine.CreateDatabase() + suite.NoError(err) + suite.db = db +} + +func (suite *simpleMigratorTestSuite) TearDownTest() { + suite.db.DropDB() +} + +func (suite *simpleMigratorTestSuite) TestPlanAndApplyMigration() { + initialDDL := ` + CREATE TABLE foobar( + id CHAR(16) PRIMARY KEY + ); ` + newSchemaDDL := ` + CREATE TABLE foobar( + id CHAR(16) PRIMARY KEY, + new_column VARCHAR(128) NOT NULL + ); + ` + + suite.mustApplyDDLToTestDb([]string{initialDDL}) + + conn, poolCloser := suite.mustGetTestDBConn() + defer poolCloser.Close() + defer conn.Close() + + tempDbFactory := suite.mustBuildTempDbFactory(context.Background()) + defer tempDbFactory.Close() + + plan, err := diff.GeneratePlan(context.Background(), conn, tempDbFactory, []string{newSchemaDDL}) + suite.NoError(err) + + // Run the migration + for _, stmt := range plan.Statements { + _, err = conn.ExecContext(context.Background(), stmt.ToSQL()) + suite.Require().NoError(err) + } + // Ensure that some sort of migration ran. we're really not testing the correctness of the + // migration in this test suite + _, err = conn.ExecContext(context.Background(), + "SELECT new_column FROM foobar;") + suite.NoError(err) +} + +func (suite *simpleMigratorTestSuite) TestCannotPackNewTablesWithoutIgnoringChangesToColumnOrder() { + tempDbFactory := suite.mustBuildTempDbFactory(context.Background()) + defer tempDbFactory.Close() + + conn, poolCloser := suite.mustGetTestDBConn() + defer poolCloser.Close() + defer conn.Close() + + _, err := diff.GeneratePlan(context.Background(), conn, tempDbFactory, []string{``}, diff.WithDataPackNewTables()) + suite.ErrorContains(err, "cannot data pack new tables without also ignoring changes to column order") +} + +func TestSimpleMigratorTestSuite(t *testing.T) { + suite.Run(t, new(simpleMigratorTestSuite)) +} diff --git a/pkg/diff/plan_test.go b/pkg/diff/plan_test.go new file mode 100644 index 0000000..f15b6a4 --- /dev/null +++ b/pkg/diff/plan_test.go @@ -0,0 +1,240 @@ +package diff_test + +import ( + "regexp" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stripe/pg-schema-diff/pkg/diff" +) + +func TestPlan_ApplyStatementTimeoutModifier(t *testing.T) { + for _, tc := range []struct { + name string + regex string + timeout time.Duration + plan diff.Plan + expectedPlan diff.Plan + }{ + { + name: "no matches", + regex: "^.*x.*y$", + timeout: 2 * time.Hour, + plan: diff.Plan{ + Statements: []diff.Statement{ + { + DDL: "does-not-match-1", + Timeout: 3 * time.Second, + }, + { + DDL: "does-not-match-2", + Timeout: time.Second, + }, + { + DDL: "does-not-match-3", + Timeout: 2 * time.Second, + }, + }, + CurrentSchemaHash: "some-hash", + }, + expectedPlan: diff.Plan{ + Statements: []diff.Statement{ + { + DDL: "does-not-match-1", + Timeout: 3 * time.Second, + }, + { + DDL: "does-not-match-2", + Timeout: time.Second, + }, + { + DDL: "does-not-match-3", + Timeout: 2 * time.Second, + }, + }, + CurrentSchemaHash: "some-hash", + }, + }, + { + name: "some match", + regex: "^.*x.*y$", + timeout: 2 * time.Hour, + plan: diff.Plan{ + Statements: []diff.Statement{ + { + DDL: "some-letters-than-an-x--end-in-y", + Timeout: 3 * time.Second, + }, + { + DDL: "does-not-match", + Timeout: time.Second, + }, + { + DDL: "other-letters xerox but ends in finally", + Timeout: 2 * time.Second, + }, + }, + CurrentSchemaHash: "some-hash", + }, + expectedPlan: diff.Plan{ + Statements: []diff.Statement{ + { + DDL: "some-letters-than-an-x--end-in-y", + Timeout: 2 * time.Hour, + }, + { + DDL: "does-not-match", + Timeout: time.Second, + }, + { + DDL: "other-letters xerox but ends in finally", + Timeout: 2 * time.Hour, + }, + }, + CurrentSchemaHash: "some-hash", + }, + }, + } { + t.Run(tc.name, func(t *testing.T) { + regex := regexp.MustCompile(tc.regex) + resultingPlan := tc.plan.ApplyStatementTimeoutModifier(regex, tc.timeout) + assert.Equal(t, tc.expectedPlan, resultingPlan) + }) + } +} + +func TestPlan_InsertStatement(t *testing.T) { + var statementToInsert = diff.Statement{ + DDL: "some DDL", + Timeout: 3 * time.Second, + Hazards: []diff.MigrationHazard{ + {Type: diff.MigrationHazardTypeIsUserGenerated, Message: "user-generated"}, + {Type: diff.MigrationHazardTypeAcquiresShareLock, Message: "acquires share lock"}, + }, + } + + for _, tc := range []struct { + name string + plan diff.Plan + index int + + expectedPlan diff.Plan + expectedErrContains string + }{ + { + name: "empty plan", + plan: diff.Plan{ + Statements: nil, + CurrentSchemaHash: "some-hash", + }, + index: 0, + + expectedPlan: diff.Plan{ + Statements: []diff.Statement{ + statementToInsert, + }, + CurrentSchemaHash: "some-hash", + }, + }, + { + name: "insert at start", + plan: diff.Plan{ + Statements: []diff.Statement{ + {DDL: "statement 1", Timeout: time.Second}, + {DDL: "statement 2", Timeout: 2 * time.Second}, + {DDL: "statement 3", Timeout: 3 * time.Second}, + }, + CurrentSchemaHash: "some-hash", + }, + index: 0, + + expectedPlan: diff.Plan{ + Statements: []diff.Statement{ + statementToInsert, + {DDL: "statement 1", Timeout: time.Second}, + {DDL: "statement 2", Timeout: 2 * time.Second}, + {DDL: "statement 3", Timeout: 3 * time.Second}, + }, + CurrentSchemaHash: "some-hash", + }, + }, + { + name: "insert after first statement", + plan: diff.Plan{ + Statements: []diff.Statement{ + {DDL: "statement 1", Timeout: time.Second}, + {DDL: "statement 2", Timeout: 2 * time.Second}, + {DDL: "statement 3", Timeout: 3 * time.Second}, + }, + CurrentSchemaHash: "some-hash", + }, + index: 1, + + expectedPlan: diff.Plan{ + Statements: []diff.Statement{ + + {DDL: "statement 1", Timeout: time.Second}, + statementToInsert, + {DDL: "statement 2", Timeout: 2 * time.Second}, + {DDL: "statement 3", Timeout: 3 * time.Second}, + }, + CurrentSchemaHash: "some-hash", + }, + }, + { + name: "insert after last statement statement", + plan: diff.Plan{ + Statements: []diff.Statement{ + {DDL: "statement 1", Timeout: time.Second}, + {DDL: "statement 2", Timeout: 2 * time.Second}, + {DDL: "statement 3", Timeout: 3 * time.Second}, + }, + CurrentSchemaHash: "some-hash", + }, + index: 3, + + expectedPlan: diff.Plan{ + Statements: []diff.Statement{ + {DDL: "statement 1", Timeout: time.Second}, + {DDL: "statement 2", Timeout: 2 * time.Second}, + {DDL: "statement 3", Timeout: 3 * time.Second}, + statementToInsert, + }, + CurrentSchemaHash: "some-hash", + }, + }, + { + name: "errors on negative index", + plan: diff.Plan{ + Statements: nil, + CurrentSchemaHash: "some-hash", + }, + index: -1, + + expectedErrContains: "index must be", + }, + { + name: "errors on index greater than len", + plan: diff.Plan{ + Statements: nil, + CurrentSchemaHash: "some-hash", + }, + index: 1, + + expectedErrContains: "index must be", + }, + } { + t.Run(tc.name, func(t *testing.T) { + resultingPlan, err := tc.plan.InsertStatement(tc.index, statementToInsert) + if len(tc.expectedErrContains) > 0 { + assert.ErrorContains(t, err, tc.expectedErrContains) + return + } + require.NoError(t, err) + assert.Equal(t, tc.expectedPlan, resultingPlan) + }) + } +} diff --git a/pkg/diff/schema_migration_plan_test.go b/pkg/diff/schema_migration_plan_test.go new file mode 100644 index 0000000..b9596b0 --- /dev/null +++ b/pkg/diff/schema_migration_plan_test.go @@ -0,0 +1,1163 @@ +package diff + +import ( + "testing" + + "github.com/google/uuid" + "github.com/kr/pretty" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stripe/pg-schema-diff/internal/schema" +) + +type schemaMigrationPlanTestCase struct { + name string + oldSchema schema.Schema + newSchema schema.Schema + expectedStatements []Statement + expectedDiffErrIs error +} + +// schemaMigrationPlanTestCases -- these test cases assert the exact migration plan that is expected +// to be generated when migrating from the oldSchema to the newSchema. They assert how the migration +// should occur by asserting the DDL +// +// Most test cases should be added to //pg-schema-diff/internal/migration_acceptance_test_cases (acceptance +// tests) instead of here. +// +// The acceptance tests actually fetch the old/new schemas; run the migration; and validate the migration +// updates the old schema to be equivalent to the new schema. However, they do not assert any DDL; they have +// no expectation on how the migration should be done. +// +// The tests added here should just cover niche cases where you want to assert HOW the migration should be done (e.g., +// adding an index concurrently) +var ( + defaultCollation = schema.SchemaQualifiedName{ + EscapedName: `"default"`, + SchemaName: "pg_catalog", + } + cCollation = schema.SchemaQualifiedName{ + EscapedName: `"C"`, + SchemaName: "pg_catalog", + } + + schemaMigrationPlanTestCases = []schemaMigrationPlanTestCase{ + { + name: "No-op", + oldSchema: schema.Schema{ + Tables: []schema.Table{ + { + Name: "foobar", + Columns: []schema.Column{ + {Name: "id", Type: "integer"}, + {Name: "foo", Type: "character varying(255)", Default: "''::character varying", Collation: defaultCollation}, + {Name: "bar", Type: "timestamp without time zone", IsNullable: true, Default: "CURRENT_TIMESTAMP"}, + {Name: "fizz", Type: "boolean", Default: "CURRENT_TIMESTAMP"}, + }, + CheckConstraints: []schema.CheckConstraint{ + {Name: "id_check", Expression: "(id > 0)", IsInheritable: true}, + {Name: "bar_check", Expression: "(id > LENGTH(foo))", IsValid: true}, + }, + }, + { + Name: "bar", + Columns: []schema.Column{ + {Name: "id", Type: "character varying(255)", Default: "", IsNullable: false, Collation: defaultCollation}, + {Name: "foo", Type: "integer", Default: "", IsNullable: true}, + {Name: "bar", Type: "double precision", Default: "8.8", IsNullable: false}, + {Name: "fizz", Type: "timestamp with time zone", Default: "CURRENT_TIMESTAMP", IsNullable: true}, + {Name: "buzz", Type: "real", Default: "", IsNullable: false}, + }, + CheckConstraints: []schema.CheckConstraint{ + {Name: "id_check", Expression: "(id + 10 > 0 )", IsValid: true}, + {Name: "bar_check", Expression: "(foo > buzz)", IsInheritable: true}, + }, + }, + { + Name: "fizz", + Columns: nil, + CheckConstraints: nil, + }, + }, + Indexes: []schema.Index{ + // bar indexes + { + TableName: "bar", + Name: "bar_pkey", Columns: []string{"bar"}, IsPk: true, IsUnique: true, ConstraintName: "bar_pkey_non_default_name", + GetIndexDefStmt: "CREATE UNIQUE INDEX bar_pkey ON public.bar USING btree (bar)", + }, + { + TableName: "bar", + Name: "bar_normal_idx", Columns: []string{"fizz"}, + GetIndexDefStmt: "CREATE INDEX bar_normal_idx ON public.bar USING btree (fizz)", + }, + { + TableName: "bar", + Name: "bar_unique_idx", IsUnique: true, Columns: []string{"fizz", "buzz"}, + GetIndexDefStmt: "CREATE UNIQUE INDEX bar_unique_idx ON public.bar USING btree (fizz, buzz)", + }, + // foobar indexes + { + TableName: "foobar", + Name: "foobar_pkey", Columns: []string{"id"}, IsPk: true, IsUnique: true, ConstraintName: "foobar_pkey", + GetIndexDefStmt: "CREATE UNIQUE INDEX foo_pkey ON public.foobar USING btree (id)", + }, + { + TableName: "foobar", + Name: "foobar_normal_idx", Columns: []string{"fizz"}, + GetIndexDefStmt: "CREATE INDEX foobar_normal_idx ON public.foobar USING btree (fizz)", + }, + { + TableName: "foobar", + Name: "foobar_unique_idx", IsUnique: true, Columns: []string{"foo", "bar"}, + GetIndexDefStmt: "CREATE UNIQUE INDEX foobar_unique_idx ON public.foobar USING btree (foo, bar)", + }, + }, + }, + newSchema: schema.Schema{ + Tables: []schema.Table{ + { + Name: "foobar", + Columns: []schema.Column{ + {Name: "id", Type: "integer"}, + {Name: "foo", Type: "character varying(255)", Default: "''::character varying", Collation: defaultCollation}, + {Name: "bar", Type: "timestamp without time zone", IsNullable: true, Default: "CURRENT_TIMESTAMP"}, + {Name: "fizz", Type: "boolean", Default: "CURRENT_TIMESTAMP"}, + }, + CheckConstraints: []schema.CheckConstraint{ + {Name: "id_check", Expression: "(id > 0)", IsInheritable: true}, + {Name: "bar_check", Expression: "(id > LENGTH(foo))", IsValid: true}, + }, + }, + { + Name: "bar", + Columns: []schema.Column{ + {Name: "id", Type: "character varying(255)", Default: "", IsNullable: false, Collation: defaultCollation}, + {Name: "foo", Type: "integer", Default: "", IsNullable: true}, + {Name: "bar", Type: "double precision", Default: "8.8", IsNullable: false}, + {Name: "fizz", Type: "timestamp with time zone", Default: "CURRENT_TIMESTAMP", IsNullable: true}, + {Name: "buzz", Type: "real", Default: "", IsNullable: false}, + }, + CheckConstraints: []schema.CheckConstraint{ + {Name: "id_check", Expression: "(id + 10 > 0 )", IsValid: true}, + {Name: "bar_check", Expression: "(foo > buzz)", IsInheritable: true}, + }, + }, + { + Name: "fizz", + Columns: nil, + CheckConstraints: nil, + }, + }, + Indexes: []schema.Index{ + // bar indexes + { + TableName: "bar", + Name: "bar_pkey", Columns: []string{"bar"}, IsPk: true, IsUnique: true, ConstraintName: "bar_pkey_non_default_name", + GetIndexDefStmt: "CREATE UNIQUE INDEX bar_pkey ON public.bar USING btree (bar)", + }, + { + TableName: "bar", + Name: "bar_normal_idx", Columns: []string{"fizz"}, + GetIndexDefStmt: "CREATE INDEX bar_normal_idx ON public.bar USING btree (fizz)", + }, + { + TableName: "bar", + Name: "bar_unique_idx", IsUnique: true, Columns: []string{"fizz", "buzz"}, + GetIndexDefStmt: "CREATE UNIQUE INDEX bar_unique_idx ON public.bar USING btree (fizz, buzz)", + }, + // foobar indexes + { + TableName: "foobar", + Name: "foobar_pkey", Columns: []string{"id"}, IsPk: true, IsUnique: true, ConstraintName: "foobar_pkey", + GetIndexDefStmt: "CREATE UNIQUE INDEX foo_pkey ON public.foobar USING btree (id)", + }, + { + TableName: "foobar", + Name: "foobar_normal_idx", Columns: []string{"fizz"}, + GetIndexDefStmt: "CREATE INDEX foobar_normal_idx ON public.foobar USING btree (fizz)", + }, + { + TableName: "foobar", + Name: "foobar_unique_idx", IsUnique: true, Columns: []string{"foo", "bar"}, + GetIndexDefStmt: "CREATE UNIQUE INDEX foobar_unique_idx ON public.foobar USING btree (foo, bar)", + }, + }, + }, + expectedStatements: nil, + }, + { + name: "Index replacement", + oldSchema: schema.Schema{ + Tables: []schema.Table{ + { + Name: "foobar", + Columns: []schema.Column{ + {Name: "id", Type: "integer"}, + {Name: "foo", Type: "character varying(255)", Default: "''::character varying", Collation: defaultCollation}, + {Name: "bar", Type: "timestamp without time zone", IsNullable: true, Default: "CURRENT_TIMESTAMP"}, + }, + CheckConstraints: nil, + }, + }, + Indexes: []schema.Index{ + { + TableName: "foobar", + Name: "foo_idx", Columns: []string{"foo"}, + GetIndexDefStmt: "CREATE INDEX foo_idx ON public.foobar USING btree (foo)", + }, + { + TableName: "foobar", + Name: "replaced_with_same_name_idx", Columns: []string{"bar"}, + GetIndexDefStmt: "CREATE INDEX replaced_with_same_name_idx ON ONLY public.foobar USING btree (bar)", + }, + }, + }, + newSchema: schema.Schema{ + Tables: []schema.Table{ + { + Name: "foobar", + Columns: []schema.Column{ + {Name: "id", Type: "integer"}, + {Name: "foo", Type: "character varying(255)", Default: "''::character varying", Collation: defaultCollation}, + {Name: "bar", Type: "timestamp without time zone", IsNullable: true, Default: "CURRENT_TIMESTAMP"}, + }, + CheckConstraints: nil, + }, + }, + Indexes: []schema.Index{ + { + TableName: "foobar", + Name: "new_foo_idx", Columns: []string{"foo"}, + GetIndexDefStmt: "CREATE INDEX new_foo_idx ON public.foobar USING btree (foo)", + }, + { + TableName: "foobar", + Name: "replaced_with_same_name_idx", Columns: []string{"bar", "foo"}, + GetIndexDefStmt: "CREATE INDEX replaced_with_same_name_idx ON ONLY public.foobar USING btree (bar)", + }, + }, + }, + expectedStatements: []Statement{ + { + DDL: "ALTER INDEX \"replaced_with_same_name_idx\" RENAME TO \"replaced_with_same_name_id_00010203-0405-4607-8809-0a0b0c0d0e0f\"", + Timeout: statementTimeoutDefault, + Hazards: nil, + }, + { + DDL: "CREATE INDEX CONCURRENTLY new_foo_idx ON public.foobar USING btree (foo)", + Timeout: statementTimeoutConcurrentIndexBuild, + Hazards: []MigrationHazard{buildIndexBuildHazard()}, + }, + { + DDL: "CREATE INDEX CONCURRENTLY replaced_with_same_name_idx ON ONLY public.foobar USING btree (bar)", + Timeout: statementTimeoutConcurrentIndexBuild, + Hazards: []MigrationHazard{buildIndexBuildHazard()}, + }, + { + DDL: "DROP INDEX CONCURRENTLY \"foo_idx\"", + Timeout: statementTimeoutConcurrentIndexDrop, + Hazards: []MigrationHazard{ + {Type: "INDEX_DROPPED", Message: "Dropping this index means queries that use this index might perform worse because they will no longer will be able to leverage it."}, + }, + }, + { + DDL: "DROP INDEX CONCURRENTLY \"replaced_with_same_name_id_00010203-0405-4607-8809-0a0b0c0d0e0f\"", + Timeout: statementTimeoutConcurrentIndexDrop, + Hazards: []MigrationHazard{ + {Type: "INDEX_DROPPED", Message: "Dropping this index means queries that use this index might perform worse because they will no longer will be able to leverage it."}, + }, + }, + }, + }, + { + name: "Index dropped concurrently before columns dropped", + oldSchema: schema.Schema{ + Tables: []schema.Table{ + { + Name: "foobar", + Columns: []schema.Column{ + {Name: "id", Type: "integer"}, + {Name: "foo", Type: "character varying(255)", Default: "''::character varying", Collation: defaultCollation}, + {Name: "bar", Type: "timestamp without time zone", IsNullable: true, Default: "CURRENT_TIMESTAMP"}, + }, + CheckConstraints: nil, + }, + }, + Indexes: []schema.Index{ + { + TableName: "foobar", + Name: "foobar_pkey", Columns: []string{"id"}, IsPk: true, IsUnique: true, ConstraintName: "foobar_pkey", + GetIndexDefStmt: "CREATE UNIQUE INDEX foobar_pkey ON public.foobar USING btree (id)", + }, + { + TableName: "foobar", + Name: "some_idx", Columns: []string{"foo, bar"}, + GetIndexDefStmt: "CREATE INDEX some_idx ON public.foobar USING btree (foo, bar)", + }, + }, + }, + newSchema: schema.Schema{ + Tables: []schema.Table{ + { + Name: "foobar", + Columns: []schema.Column{ + {Name: "id", Type: "integer"}, + }, + CheckConstraints: nil, + }, + }, + Indexes: []schema.Index{ + { + TableName: "foobar", + Name: "foobar_pkey", Columns: []string{"id"}, IsPk: true, IsUnique: true, ConstraintName: "foobar_pkey", + GetIndexDefStmt: "CREATE UNIQUE INDEX foobar_pkey ON public.foobar USING btree (id)", + }, + }, + }, + expectedStatements: []Statement{ + { + DDL: "DROP INDEX CONCURRENTLY \"some_idx\"", + Timeout: statementTimeoutConcurrentIndexDrop, + Hazards: []MigrationHazard{buildIndexDroppedQueryPerfHazard()}, + }, + { + DDL: "ALTER TABLE \"foobar\" DROP COLUMN \"bar\"", + Timeout: statementTimeoutDefault, + Hazards: []MigrationHazard{buildColumnDataDeletionHazard()}, + }, + { + DDL: "ALTER TABLE \"foobar\" DROP COLUMN \"foo\"", + Timeout: statementTimeoutDefault, + Hazards: []MigrationHazard{buildColumnDataDeletionHazard()}, + }, + }, + }, + { + name: "Invalid index re-created", + oldSchema: schema.Schema{ + Tables: []schema.Table{ + { + Name: "foobar", + Columns: []schema.Column{ + {Name: "id", Type: "integer"}, + {Name: "foo", Type: "character varying(255)", Default: "''::character varying", Collation: defaultCollation}, + {Name: "bar", Type: "timestamp without time zone", IsNullable: true, Default: "CURRENT_TIMESTAMP"}, + }, + CheckConstraints: nil, + }, + }, + Indexes: []schema.Index{ + { + TableName: "foobar", + Name: "some_idx", Columns: []string{"foo", "bar"}, + GetIndexDefStmt: "CREATE INDEX some_idx ON public.foobar USING btree (foo, bar)", + IsUnique: true, IsInvalid: true, + }, + }, + }, + newSchema: schema.Schema{ + Tables: []schema.Table{ + { + Name: "foobar", + Columns: []schema.Column{ + {Name: "id", Type: "integer"}, + {Name: "foo", Type: "character varying(255)", Default: "''::character varying", Collation: defaultCollation}, + {Name: "bar", Type: "timestamp without time zone", IsNullable: true, Default: "CURRENT_TIMESTAMP"}, + }, + CheckConstraints: nil, + }, + }, + Indexes: []schema.Index{ + + { + TableName: "foobar", + Name: "some_idx", Columns: []string{"foo", "bar"}, + GetIndexDefStmt: "CREATE INDEX some_idx ON public.foobar USING btree (foo, bar)", + IsUnique: true, + }, + }, + }, + expectedStatements: []Statement{ + { + DDL: "ALTER INDEX \"some_idx\" RENAME TO \"some_idx_10111213-1415-4617-9819-1a1b1c1d1e1f\"", + Timeout: statementTimeoutDefault, + }, + { + DDL: "CREATE INDEX CONCURRENTLY some_idx ON public.foobar USING btree (foo, bar)", + Timeout: statementTimeoutConcurrentIndexBuild, + Hazards: []MigrationHazard{buildIndexBuildHazard()}, + }, + { + DDL: "DROP INDEX CONCURRENTLY \"some_idx_10111213-1415-4617-9819-1a1b1c1d1e1f\"", + Timeout: statementTimeoutConcurrentIndexDrop, + Hazards: []MigrationHazard{buildIndexDroppedQueryPerfHazard()}, + }, + }, + }, + { + name: "Index replacement on partitioned table", + oldSchema: schema.Schema{ + Tables: []schema.Table{ + { + Name: "foobar", + Columns: []schema.Column{ + {Name: "id", Type: "integer"}, + {Name: "foo", Type: "character varying(255)", Default: "''::character varying", Collation: defaultCollation}, + {Name: "bar", Type: "timestamp without time zone", IsNullable: true, Default: "CURRENT_TIMESTAMP"}, + }, + CheckConstraints: nil, + PartitionKeyDef: "PARTITION BY LIST(foo)", + }, + { + ParentTableName: "foobar", + Name: "foobar_1", + Columns: []schema.Column{ + {Name: "id", Type: "integer"}, + {Name: "foo", Type: "character varying(255)", Default: "''::character varying", Collation: defaultCollation}, + {Name: "bar", Type: "timestamp without time zone", IsNullable: true, Default: "CURRENT_TIMESTAMP"}, + }, + CheckConstraints: nil, + ForValues: "FOR VALUES IN ('some_val')", + }, + { + ParentTableName: "foobar", + Name: "foobar_2", + Columns: []schema.Column{ + {Name: "id", Type: "integer"}, + {Name: "foo", Type: "character varying(255)", Default: "''::character varying", Collation: defaultCollation}, + {Name: "bar", Type: "timestamp without time zone", IsNullable: true, Default: "CURRENT_TIMESTAMP"}, + }, + CheckConstraints: nil, + ForValues: "FOR VALUES IN ('some_other_val')", + }, + }, + Indexes: []schema.Index{ + // foobar indexes + { + TableName: "foobar", + Name: "some_idx", Columns: []string{"foo", "bar"}, + GetIndexDefStmt: "CREATE INDEX some_idx ON ONLY public.foobar USING btree (foo, bar)", + }, + { + TableName: "foobar", + Name: "replaced_with_same_name_idx", Columns: []string{"bar"}, + GetIndexDefStmt: "CREATE INDEX replaced_with_same_name_idx ON ONLY public.foobar USING btree (bar)", + }, + // foobar_1 indexes + { + TableName: "foobar_1", + Name: "foobar_1_some_idx", Columns: []string{"foo", "bar"}, ParentIdxName: "some_idx", + GetIndexDefStmt: "CREATE INDEX foobar_1_some_idx ON public.foobar_1 USING btree (foo, bar)", + }, + { + TableName: "foobar_1", + Name: "foobar_1_replaced_with_same_name_idx", Columns: []string{"bar"}, ParentIdxName: "replaced_with_same_name_idx", + GetIndexDefStmt: "CREATE INDEX foobar_1_replaced_with_same_name_idx ON ONLY public.foobar USING btree (bar)", + }, + { + TableName: "foobar_1", + Name: "foobar_1_some_local_idx", Columns: []string{"foo", "bar", "id"}, + GetIndexDefStmt: "CREATE INDEX foobar_1_some_local_idx ON public.foobar_1 USING btree (foo, bar, id)", + }, + // foobar_2 indexes + { + TableName: "foobar_2", + Name: "foobar_2_some_idx", Columns: []string{"foo", "bar"}, ParentIdxName: "some_idx", + GetIndexDefStmt: "CREATE INDEX foobar_2_some_idx ON public.foobar_2 USING btree (foo, bar)", + }, + { + TableName: "foobar_2", + Name: "foobar_2_replaced_with_same_name_idx", Columns: []string{"bar"}, ParentIdxName: "replaced_with_same_name_idx", + GetIndexDefStmt: "CREATE INDEX foobar_2_replaced_with_same_name_idx ON ONLY public.foobar USING btree (bar)", + }, + }, + }, + newSchema: schema.Schema{ + Tables: []schema.Table{ + { + Name: "foobar", + Columns: []schema.Column{ + {Name: "id", Type: "integer"}, + {Name: "foo", Type: "character varying(255)", Default: "''::character varying", Collation: defaultCollation}, + {Name: "bar", Type: "timestamp without time zone", IsNullable: true, Default: "CURRENT_TIMESTAMP"}, + }, + CheckConstraints: nil, + PartitionKeyDef: "PARTITION BY LIST(foo)", + }, + { + ParentTableName: "foobar", + Name: "foobar_1", + Columns: []schema.Column{ + {Name: "id", Type: "integer"}, + {Name: "foo", Type: "character varying(255)", Default: "''::character varying", Collation: defaultCollation}, + {Name: "bar", Type: "timestamp without time zone", IsNullable: true, Default: "CURRENT_TIMESTAMP"}, + }, + CheckConstraints: nil, + ForValues: "FOR VALUES IN ('some_val')", + }, + { + ParentTableName: "foobar", + Name: "foobar_2", + Columns: []schema.Column{ + {Name: "id", Type: "integer"}, + {Name: "foo", Type: "character varying(255)", Default: "''::character varying", Collation: defaultCollation}, + {Name: "bar", Type: "timestamp without time zone", IsNullable: true, Default: "CURRENT_TIMESTAMP"}, + }, + CheckConstraints: nil, + ForValues: "FOR VALUES IN ('some_other_val')", + }, + }, + Indexes: []schema.Index{ + // foobar indexes + { + TableName: "foobar", + Name: "new_some_idx", Columns: []string{"foo", "bar"}, + GetIndexDefStmt: "CREATE INDEX new_some_idx ON ONLY public.foobar USING btree (foo, bar)", + }, + { + TableName: "foobar", + Name: "replaced_with_same_name_idx", Columns: []string{"bar", "foo"}, + GetIndexDefStmt: "CREATE INDEX replaced_with_same_name_idx ON ONLY public.foobar USING btree (bar, foo)", + }, + // foobar_1 indexes + { + TableName: "foobar_1", + Name: "new_foobar_1_some_idx", Columns: []string{"foo", "bar"}, ParentIdxName: "new_some_idx", + GetIndexDefStmt: "CREATE INDEX new_foobar_1_some_idx ON public.foobar_1 USING btree (foo, bar)", + }, + { + TableName: "foobar_1", + Name: "foobar_1_replaced_with_same_name_idx", Columns: []string{"bar", "foo"}, ParentIdxName: "replaced_with_same_name_idx", + GetIndexDefStmt: "CREATE INDEX foobar_1_replaced_with_same_name_idx ON public.foobar USING btree (bar, foo)", + }, + { + TableName: "foobar_1", + Name: "new_foobar_1_some_local_idx", Columns: []string{"foo", "bar", "id"}, + GetIndexDefStmt: "CREATE INDEX new_foobar_1_some_local_idx ON public.foobar_1 USING btree (foo, bar, id)", + }, + // foobar_2 indexes + { + TableName: "foobar_2", + Name: "new_foobar_2_some_idx", Columns: []string{"foo", "bar"}, ParentIdxName: "new_some_idx", + GetIndexDefStmt: "CREATE INDEX new_foobar_2_some_idx ON public.foobar_2 USING btree (foo, bar)", + }, + { + TableName: "foobar_2", + Name: "foobar_2_replaced_with_same_name_idx", Columns: []string{"bar", "foo"}, ParentIdxName: "replaced_with_same_name_idx", + GetIndexDefStmt: "CREATE INDEX foobar_2_replaced_with_same_name_idx ON public.foobar_2 USING btree (bar, foo)", + }, + }, + }, + expectedStatements: []Statement{ + { + DDL: "ALTER INDEX \"foobar_1_replaced_with_same_name_idx\" RENAME TO \"foobar_1_replaced_with_sam_30313233-3435-4637-b839-3a3b3c3d3e3f\"", + Timeout: statementTimeoutDefault, + }, + { + DDL: "ALTER INDEX \"foobar_2_replaced_with_same_name_idx\" RENAME TO \"foobar_2_replaced_with_sam_40414243-4445-4647-8849-4a4b4c4d4e4f\"", + Timeout: statementTimeoutDefault, + }, + { + DDL: "ALTER INDEX \"replaced_with_same_name_idx\" RENAME TO \"replaced_with_same_name_id_20212223-2425-4627-a829-2a2b2c2d2e2f\"", + Timeout: statementTimeoutDefault, + }, + { + DDL: "CREATE INDEX new_some_idx ON ONLY public.foobar USING btree (foo, bar)", + Timeout: statementTimeoutDefault, + }, + { + DDL: "CREATE INDEX replaced_with_same_name_idx ON ONLY public.foobar USING btree (bar, foo)", + Timeout: statementTimeoutDefault, + }, + { + DDL: "CREATE INDEX CONCURRENTLY foobar_1_replaced_with_same_name_idx ON public.foobar USING btree (bar, foo)", + Timeout: statementTimeoutConcurrentIndexDrop, + Hazards: []MigrationHazard{ + buildIndexBuildHazard(), + }, + }, + { + DDL: "ALTER INDEX \"replaced_with_same_name_idx\" ATTACH PARTITION \"foobar_1_replaced_with_same_name_idx\"", + Timeout: statementTimeoutDefault, + Hazards: nil, + }, + { + DDL: "CREATE INDEX CONCURRENTLY new_foobar_1_some_idx ON public.foobar_1 USING btree (foo, bar)", + Timeout: statementTimeoutConcurrentIndexBuild, + Hazards: []MigrationHazard{ + buildIndexBuildHazard(), + }, + }, + { + DDL: "ALTER INDEX \"new_some_idx\" ATTACH PARTITION \"new_foobar_1_some_idx\"", + Timeout: statementTimeoutDefault, + Hazards: nil, + }, + { + DDL: "CREATE INDEX CONCURRENTLY new_foobar_1_some_local_idx ON public.foobar_1 USING btree (foo, bar, id)", + Timeout: statementTimeoutConcurrentIndexBuild, + Hazards: []MigrationHazard{ + buildIndexBuildHazard(), + }, + }, + { + DDL: "CREATE INDEX CONCURRENTLY foobar_2_replaced_with_same_name_idx ON public.foobar_2 USING btree (bar, foo)", + Timeout: statementTimeoutConcurrentIndexBuild, + Hazards: []MigrationHazard{ + buildIndexBuildHazard(), + }, + }, + { + DDL: "ALTER INDEX \"replaced_with_same_name_idx\" ATTACH PARTITION \"foobar_2_replaced_with_same_name_idx\"", + Timeout: statementTimeoutDefault, + Hazards: nil, + }, + { + DDL: "CREATE INDEX CONCURRENTLY new_foobar_2_some_idx ON public.foobar_2 USING btree (foo, bar)", + Timeout: statementTimeoutConcurrentIndexBuild, + Hazards: []MigrationHazard{ + buildIndexBuildHazard(), + }, + }, + { + DDL: "ALTER INDEX \"new_some_idx\" ATTACH PARTITION \"new_foobar_2_some_idx\"", + Timeout: statementTimeoutDefault, + Hazards: nil, + }, + { + DDL: "DROP INDEX CONCURRENTLY \"foobar_1_some_local_idx\"", + Timeout: statementTimeoutConcurrentIndexDrop, + Hazards: []MigrationHazard{ + buildIndexDroppedQueryPerfHazard(), + }, + }, + { + DDL: "DROP INDEX \"replaced_with_same_name_id_20212223-2425-4627-a829-2a2b2c2d2e2f\"", + Timeout: statementTimeoutDefault, + Hazards: []MigrationHazard{ + buildIndexDroppedAcquiresLockHazard(), + buildIndexDroppedQueryPerfHazard(), + }, + }, + { + DDL: "DROP INDEX \"some_idx\"", + Timeout: statementTimeoutDefault, + Hazards: []MigrationHazard{ + buildIndexDroppedAcquiresLockHazard(), + buildIndexDroppedQueryPerfHazard(), + }, + }, + }, + }, + { + name: "Local Index dropped concurrently before columns dropped; partitioned index just dropped", + oldSchema: schema.Schema{ + Tables: []schema.Table{ + { + Name: "foobar", + Columns: []schema.Column{ + {Name: "id", Type: "integer"}, + {Name: "foo", Type: "character varying(255)", Default: "''::character varying", Collation: defaultCollation}, + {Name: "bar", Type: "timestamp without time zone", IsNullable: true, Default: "CURRENT_TIMESTAMP"}, + }, + CheckConstraints: nil, + PartitionKeyDef: "PARTITION BY LIST(foo)", + }, + { + ParentTableName: "foobar", + Name: "foobar_1", + Columns: []schema.Column{ + {Name: "id", Type: "integer"}, + {Name: "foo", Type: "character varying(255)", Default: "''::character varying", Collation: defaultCollation}, + {Name: "bar", Type: "timestamp without time zone", IsNullable: true, Default: "CURRENT_TIMESTAMP"}, + }, + CheckConstraints: nil, + ForValues: "FOR VALUES IN ('some_val')", + }, + }, + Indexes: []schema.Index{ + // foobar indexes + { + TableName: "foobar", + Name: "foobar_pkey", Columns: []string{"foo", "id"}, IsPk: true, IsUnique: true, ConstraintName: "foobar_pkey", + GetIndexDefStmt: "CREATE UNIQUE INDEX foobar_pkey ON ONLY public.foobar USING btree (foo, id)", + }, + { + TableName: "foobar", + Name: "some_idx", Columns: []string{"foo, bar"}, + GetIndexDefStmt: "CREATE INDEX some_idx ON ONLY public.foobar USING btree (foo, bar)", + }, + // foobar_1 indexes + { + TableName: "foobar_1", + Name: "foobar_1_pkey", Columns: []string{"foo", "id"}, IsPk: true, IsUnique: true, ConstraintName: "foobar_pkey", ParentIdxName: "foobar_pkey", + GetIndexDefStmt: "CREATE UNIQUE INDEX foobar_1_pkey ON public.foobar_1 USING btree (foo, id)", + }, + { + TableName: "foobar_1", + Name: "foobar_1_some_idx", Columns: []string{"foo", "bar"}, ParentIdxName: "some_idx", + GetIndexDefStmt: "CREATE INDEX foobar_1_some_idx ON public.foobar_1 USING btree (foo, bar)", + }, + { + TableName: "foobar_1", + Name: "foobar_1_some_local_idx", Columns: []string{"foo", "bar", "id"}, + GetIndexDefStmt: "CREATE INDEX foobar_1_some_local_idx ON public.foobar_1 USING btree (foo, bar, id)", + }, + }, + }, + newSchema: schema.Schema{ + Tables: []schema.Table{ + { + Name: "foobar", + Columns: []schema.Column{ + {Name: "id", Type: "integer"}, + {Name: "foo", Type: "character varying(255)", Default: "''::character varying", Collation: defaultCollation}, + }, + CheckConstraints: nil, + PartitionKeyDef: "PARTITION BY LIST(foo)", + }, + { + ParentTableName: "foobar", + Name: "foobar_1", + Columns: []schema.Column{ + {Name: "id", Type: "integer"}, + {Name: "foo", Type: "character varying(255)", Default: "''::character varying", Collation: defaultCollation}, + }, + CheckConstraints: nil, + ForValues: "FOR VALUES IN ('some_val')", + }, + }, + Indexes: []schema.Index{ + // foobar indexes + { + TableName: "foobar", + Name: "foobar_pkey", Columns: []string{"foo", "id"}, IsPk: true, IsUnique: true, ConstraintName: "foobar_pkey", + GetIndexDefStmt: "CREATE UNIQUE INDEX foobar_pkey ON ONLY public.foobar USING btree (foo, id)", + }, + // foobar_1 indexes + { + TableName: "foobar_1", + Name: "foobar_1_pkey", Columns: []string{"foo", "id"}, IsPk: true, IsUnique: true, ConstraintName: "foobar_pkey", ParentIdxName: "foobar_pkey", + GetIndexDefStmt: "CREATE UNIQUE INDEX foobar_1_pkey ON public.foobar_1 USING btree (foo, id)", + }, + }, + }, + expectedStatements: []Statement{ + { + DDL: "DROP INDEX CONCURRENTLY \"foobar_1_some_local_idx\"", + Timeout: statementTimeoutConcurrentIndexDrop, + Hazards: []MigrationHazard{ + buildIndexDroppedQueryPerfHazard(), + }, + }, + { + DDL: "DROP INDEX \"some_idx\"", + Timeout: statementTimeoutDefault, + Hazards: []MigrationHazard{ + buildIndexDroppedAcquiresLockHazard(), + buildIndexDroppedQueryPerfHazard(), + }, + }, + { + DDL: "ALTER TABLE \"foobar\" DROP COLUMN \"bar\"", + Timeout: statementTimeoutDefault, + Hazards: []MigrationHazard{ + buildColumnDataDeletionHazard(), + }, + }, + }, + }, + { + name: "Invalid index of partitioned index re-created but original index remains untouched", + oldSchema: schema.Schema{ + Tables: []schema.Table{ + { + Name: "foobar", + Columns: []schema.Column{ + {Name: "id", Type: "integer"}, + {Name: "foo", Type: "character varying(255)", Default: "''::character varying", Collation: defaultCollation}, + {Name: "bar", Type: "timestamp without time zone", IsNullable: true, Default: "CURRENT_TIMESTAMP"}, + }, + CheckConstraints: nil, + PartitionKeyDef: "PARTITION BY LIST(foo)", + }, + { + ParentTableName: "foobar", + Name: "foobar_1", + Columns: []schema.Column{ + {Name: "id", Type: "integer"}, + {Name: "foo", Type: "character varying(255)", Default: "''::character varying", Collation: defaultCollation}, + {Name: "bar", Type: "timestamp without time zone", IsNullable: true, Default: "CURRENT_TIMESTAMP"}, + }, + CheckConstraints: nil, + ForValues: "FOR VALUES IN ('some_val')", + }, + }, + Indexes: []schema.Index{ + // foobar indexes + { + TableName: "foobar", + Name: "some_idx", Columns: []string{"foo, bar"}, + GetIndexDefStmt: "CREATE INDEX some_idx ON ONLY public.foobar USING btree (foo, bar)", + IsInvalid: true, + }, + // foobar_1 indexes + { + TableName: "foobar_1", + Name: "foobar_1_some_idx", Columns: []string{"foo", "bar"}, + GetIndexDefStmt: "CREATE INDEX foobar_1_some_idx ON public.foobar_1 USING btree (foo, bar)", + IsInvalid: true, + }, + }, + }, + newSchema: schema.Schema{ + Tables: []schema.Table{ + { + Name: "foobar", + Columns: []schema.Column{ + {Name: "id", Type: "integer"}, + {Name: "foo", Type: "character varying(255)", Default: "''::character varying", Collation: defaultCollation}, + {Name: "bar", Type: "timestamp without time zone", IsNullable: true, Default: "CURRENT_TIMESTAMP"}, + }, + CheckConstraints: nil, + PartitionKeyDef: "PARTITION BY LIST(foo)", + }, + { + ParentTableName: "foobar", + Name: "foobar_1", + Columns: []schema.Column{ + {Name: "id", Type: "integer"}, + {Name: "foo", Type: "character varying(255)", Default: "''::character varying", Collation: defaultCollation}, + {Name: "bar", Type: "timestamp without time zone", IsNullable: true, Default: "CURRENT_TIMESTAMP"}, + }, + CheckConstraints: nil, + ForValues: "FOR VALUES IN ('some_val')", + }, + }, + Indexes: []schema.Index{ + // foobar indexes + { + TableName: "foobar", + Name: "some_idx", Columns: []string{"foo, bar"}, + GetIndexDefStmt: "CREATE INDEX some_idx ON ONLY public.foobar USING btree (foo, bar)", + }, + // foobar_1 indexes + { + TableName: "foobar_1", + Name: "foobar_1_some_idx", Columns: []string{"foo", "bar"}, ParentIdxName: "some_idx", + GetIndexDefStmt: "CREATE INDEX foobar_1_some_idx ON public.foobar_1 USING btree (foo, bar)", + }, + }, + }, + expectedStatements: []Statement{ + { + DDL: "ALTER INDEX \"foobar_1_some_idx\" RENAME TO \"foobar_1_some_idx_50515253-5455-4657-9859-5a5b5c5d5e5f\"", + Timeout: statementTimeoutDefault, + }, + { + DDL: "CREATE INDEX CONCURRENTLY foobar_1_some_idx ON public.foobar_1 USING btree (foo, bar)", + Timeout: statementTimeoutConcurrentIndexBuild, + Hazards: []MigrationHazard{ + buildIndexBuildHazard(), + }, + }, + { + DDL: "ALTER INDEX \"some_idx\" ATTACH PARTITION \"foobar_1_some_idx\"", + Timeout: statementTimeoutDefault, + }, + { + DDL: "DROP INDEX CONCURRENTLY \"foobar_1_some_idx_50515253-5455-4657-9859-5a5b5c5d5e5f\"", + Timeout: statementTimeoutConcurrentIndexDrop, + Hazards: []MigrationHazard{ + buildIndexDroppedQueryPerfHazard(), + }, + }, + }, + }, + { + name: "Fails on duplicate column in old schema", + oldSchema: schema.Schema{ + Tables: []schema.Table{ + { + Name: "foobar", + Columns: []schema.Column{ + {Name: "id", Type: "integer"}, + {Name: "id", Type: "character varying(255)", Default: "''::character varying", Collation: defaultCollation}, + }, + CheckConstraints: nil, + }, + }, + }, + newSchema: schema.Schema{ + Tables: []schema.Table{ + { + Name: "foobar", + Columns: []schema.Column{ + {Name: "id", Type: "integer"}, + {Name: "something", Type: "character varying(255)", Default: "''::character varying", Collation: defaultCollation}, + }, + CheckConstraints: nil, + }, + }, + }, + expectedStatements: nil, + expectedDiffErrIs: errDuplicateIdentifier, + }, + { + name: "Invalid check constraint made valid", + oldSchema: schema.Schema{ + Tables: []schema.Table{ + { + Name: "foobar", + Columns: []schema.Column{ + {Name: "id", Type: "integer"}, + }, + CheckConstraints: []schema.CheckConstraint{ + {Name: "id_check", Expression: "(id > 0)", IsInheritable: true}, + }, + }, + }, + }, + newSchema: schema.Schema{ + Tables: []schema.Table{ + { + Name: "foobar", + Columns: []schema.Column{ + {Name: "id", Type: "integer"}, + }, + CheckConstraints: []schema.CheckConstraint{ + {Name: "id_check", Expression: "(id > 0)", IsInheritable: true, IsValid: true}, + }, + }, + }, + }, + expectedStatements: []Statement{ + { + DDL: "ALTER TABLE \"foobar\" VALIDATE CONSTRAINT \"id_check\"", + Timeout: statementTimeoutDefault, + Hazards: nil, + }, + }, + }, + { + name: "Invalid check constraint re-created if expression changes", + oldSchema: schema.Schema{ + Tables: []schema.Table{ + { + Name: "foobar", + Columns: []schema.Column{ + {Name: "id", Type: "integer"}, + }, + CheckConstraints: []schema.CheckConstraint{ + {Name: "id_check", Expression: "(id > 0)", IsInheritable: true}, + }, + }, + }, + }, + newSchema: schema.Schema{ + Tables: []schema.Table{ + { + Name: "foobar", + Columns: []schema.Column{ + {Name: "id", Type: "integer"}, + }, + CheckConstraints: []schema.CheckConstraint{ + {Name: "id_check", Expression: "(id < 0)", IsInheritable: true, IsValid: true}, + }, + }, + }, + }, + expectedStatements: []Statement{ + { + DDL: "ALTER TABLE \"foobar\" DROP CONSTRAINT \"id_check\"", + Timeout: statementTimeoutDefault, + }, + { + DDL: "ALTER TABLE \"foobar\" ADD CONSTRAINT \"id_check\" CHECK((id < 0))", + Timeout: statementTimeoutDefault, + }, + }, + }, + { + name: "BIGINT to TIMESTAMP type conversion", + oldSchema: schema.Schema{ + Tables: []schema.Table{ + { + Name: "foobar", + Columns: []schema.Column{ + {Name: "baz", Type: "bigint"}, + }, + CheckConstraints: nil, + }, + }, + Indexes: nil, + }, + newSchema: schema.Schema{ + Tables: []schema.Table{ + { + Name: "foobar", + Columns: []schema.Column{ + {Name: "baz", Type: "timestamp without time zone", Default: "current_timestamp"}, + }, + CheckConstraints: nil, + }, + }, + Indexes: nil, + }, + expectedStatements: []Statement{ + { + DDL: "ALTER TABLE \"foobar\" ALTER COLUMN \"baz\" SET DATA TYPE timestamp without time zone using to_timestamp(\"baz\" / 1000)", + Timeout: statementTimeoutDefault, + Hazards: []MigrationHazard{{ + Type: MigrationHazardTypeAcquiresAccessExclusiveLock, + Message: "This will completely lock the table while the data is being " + + "re-written for a duration of time that scales with the size of your " + + "data. The values previously stored as BIGINT will be translated into a " + + "TIMESTAMP value via the PostgreSQL to_timestamp() function. This " + + "translation will assume that the values stored in BIGINT represent a " + + "millisecond epoch value.", + }}, + }, + { + DDL: "ANALYZE \"foobar\" (\"baz\")", + Timeout: statementTimeoutAnalyzeColumn, + Hazards: []MigrationHazard{buildAnalyzeColumnMigrationHazard()}, + }, + { + DDL: "ALTER TABLE \"foobar\" ALTER COLUMN \"baz\" SET DEFAULT current_timestamp", + Timeout: statementTimeoutDefault, + }, + }, + }, + { + name: "Collation migration and Type Migration", + oldSchema: schema.Schema{ + Tables: []schema.Table{ + { + Name: "foobar", + Columns: []schema.Column{ + {Name: "migrate_to_c_coll", Type: "text", Collation: defaultCollation}, + {Name: "migrate_type", Type: "text", Collation: defaultCollation}, + }, + CheckConstraints: nil, + }, + }, + }, + newSchema: schema.Schema{ + Tables: []schema.Table{ + { + Name: "foobar", + Columns: []schema.Column{ + {Name: "migrate_to_c_coll", Type: "text", Collation: cCollation}, + {Name: "migrate_type", Type: "character varying(255)", Collation: defaultCollation}, + }, + CheckConstraints: nil, + }, + }, + }, + expectedStatements: []Statement{ + { + DDL: "ALTER TABLE \"foobar\" ALTER COLUMN \"migrate_to_c_coll\" SET DATA TYPE text COLLATE \"pg_catalog\".\"C\" using \"migrate_to_c_coll\"::text", + Timeout: statementTimeoutDefault, + Hazards: []MigrationHazard{buildColumnTypeChangeHazard()}, + }, + { + DDL: "ANALYZE \"foobar\" (\"migrate_to_c_coll\")", + Timeout: statementTimeoutAnalyzeColumn, + Hazards: []MigrationHazard{buildAnalyzeColumnMigrationHazard()}, + }, + { + DDL: "ALTER TABLE \"foobar\" ALTER COLUMN \"migrate_type\" SET DATA TYPE character varying(255) COLLATE \"pg_catalog\".\"default\" using \"migrate_type\"::character varying(255)", + Timeout: statementTimeoutDefault, + Hazards: []MigrationHazard{buildColumnTypeChangeHazard()}, + }, + { + DDL: "ANALYZE \"foobar\" (\"migrate_type\")", + Timeout: statementTimeoutAnalyzeColumn, + Hazards: []MigrationHazard{buildAnalyzeColumnMigrationHazard()}, + }, + }, + }, + } +) + +type deterministicRandReader struct { + counter int8 +} + +func (r *deterministicRandReader) Read(p []byte) (int, error) { + for i := 0; i < len(p); i++ { + p[i] = byte(r.counter) + r.counter++ + } + return len(p), nil +} + +func TestSchemaMigrationPlanTest(t *testing.T) { + uuid.SetRand(&deterministicRandReader{}) + + for _, testCase := range schemaMigrationPlanTestCases { + t.Run(testCase.name, func(t *testing.T) { + schemaDiff, _, err := buildSchemaDiff(testCase.oldSchema, testCase.newSchema) + if testCase.expectedDiffErrIs != nil { + require.ErrorIs(t, err, testCase.expectedDiffErrIs) + } else { + require.NoError(t, err) + } + stmts, err := schemaSQLGenerator{}.Alter(schemaDiff) + require.NoError(t, err) + assert.Equal(t, testCase.expectedStatements, stmts, "actual:\n %# v", pretty.Formatter(stmts)) + }) + } +} + +func buildColumnDataDeletionHazard() MigrationHazard { + return MigrationHazard{ + Type: MigrationHazardTypeDeletesData, + Message: "Deletes all values in the column", + } +} + +func buildColumnTypeChangeHazard() MigrationHazard { + return MigrationHazard{ + Type: MigrationHazardTypeAcquiresAccessExclusiveLock, + Message: "This will completely lock the table while the data is being re-written. The duration of this " + + "conversion depends on if the type conversion is trivial or not. A non-trivial conversion will require a " + + "table rewrite. A trivial conversion is one where the binary values are coercible and the column contents " + + "are not changing.", + } +} + +func buildAnalyzeColumnMigrationHazard() MigrationHazard { + return MigrationHazard{ + Type: MigrationHazardTypeImpactsDatabasePerformance, + Message: "Running analyze will read rows from the table, putting increased load " + + "on the database and consuming database resources. It won't prevent reads/writes to " + + "the table, but it could affect performance when executing queries.", + } +} + +func buildIndexBuildHazard() MigrationHazard { + return MigrationHazard{ + Type: MigrationHazardTypeIndexBuild, + Message: "This might affect database performance. " + + "Concurrent index builds require a non-trivial amount of CPU, potentially affecting database performance. " + + "They also can take a while but do not lock out writes.", + } +} + +func buildIndexDroppedQueryPerfHazard() MigrationHazard { + return MigrationHazard{ + Type: MigrationHazardTypeIndexDropped, + Message: "Dropping this index means queries that use this index might perform worse because " + + "they will no longer will be able to leverage it.", + } +} + +func buildIndexDroppedAcquiresLockHazard() MigrationHazard { + return MigrationHazard{ + Type: MigrationHazardTypeAcquiresAccessExclusiveLock, + Message: "Index drops will lock out all accesses to the table. They should be fast", + } +} diff --git a/pkg/diff/sql_generator.go b/pkg/diff/sql_generator.go new file mode 100644 index 0000000..ae03298 --- /dev/null +++ b/pkg/diff/sql_generator.go @@ -0,0 +1,1431 @@ +package diff + +import ( + "fmt" + "strings" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/uuid" + "github.com/stripe/pg-schema-diff/internal/schema" +) + +const ( + maxPostgresIdentifierSize = 63 + + statementTimeoutDefault = 3 * time.Second + // statementTimeoutConcurrentIndexBuild is the statement timeout for index builds. It may take a while to build + // the index. Since it doesn't take out locks, this shouldn't be a concern + statementTimeoutConcurrentIndexBuild = 20 * time.Minute + // statementTimeoutConcurrentIndexDrop is the statement timeout for concurrent index drops. This operation shouldn't + // take out locks except when changing table metadata, but it may take a while to complete, so give it a long + // timeout + statementTimeoutConcurrentIndexDrop = 20 * time.Minute + // statementTimeoutTableDrop is the statement timeout for table drops. It may a take a while to delete the data + // Since the table is being dropped, locks shouldn't be a concern + statementTimeoutTableDrop = 20 * time.Minute + // statementTimeoutAnalyzeColumn is the statement timeout for analyzing the column of a table + statementTimeoutAnalyzeColumn = 20 * time.Minute +) + +var ( + ErrColumnOrderingChanged = fmt.Errorf("column ordering changed: %w", ErrNotImplemented) + + migrationHazardAddAlterFunctionCannotTrackDependencies = MigrationHazard{ + Type: MigrationHazardTypeHasUntrackableDependencies, + Message: "Dependencies, i.e. other functions used in the function body, of non-sql functions cannot be tracked. " + + "As a result, we cannot guarantee that function dependencies are ordered properly relative to this " + + "statement. For adds, this means you need to ensure that all functions this function depends on are " + + "created/altered before this statement.", + } + migrationHazardIndexDroppedQueryPerf = MigrationHazard{ + Type: MigrationHazardTypeIndexDropped, + Message: "Dropping this index means queries that use this index might perform worse because " + + "they will no longer will be able to leverage it.", + } + migrationHazardIndexDroppedAcquiresLock = MigrationHazard{ + Type: MigrationHazardTypeAcquiresAccessExclusiveLock, + Message: "Index drops will lock out all accesses to the table. They should be fast", + } +) + +type oldAndNew[S schema.Object] struct { + old S + new S +} + +func (o oldAndNew[S]) GetNew() S { + return o.old +} + +func (o oldAndNew[S]) GetOld() S { + return o.new +} + +type ( + columnDiff struct { + oldAndNew[schema.Column] + oldOrdering int + newOrdering int + } + + checkConstraintDiff struct { + oldAndNew[schema.CheckConstraint] + } + + tableDiff struct { + oldAndNew[schema.Table] + columnsDiff listDiff[schema.Column, columnDiff] + checkConstraintDiff listDiff[schema.CheckConstraint, checkConstraintDiff] + } + + indexDiff struct { + oldAndNew[schema.Index] + } + + functionDiff struct { + oldAndNew[schema.Function] + } + + triggerDiff struct { + oldAndNew[schema.Trigger] + } +) + +type schemaDiff struct { + oldAndNew[schema.Schema] + tableDiffs listDiff[schema.Table, tableDiff] + indexDiffs listDiff[schema.Index, indexDiff] + functionDiffs listDiff[schema.Function, functionDiff] + triggerDiffs listDiff[schema.Trigger, triggerDiff] +} + +func (sd schemaDiff) resolveToSQL() ([]Statement, error) { + return schemaSQLGenerator{}.Alter(sd) +} + +// The procedure for DIFFING schemas and GENERATING/RESOLVING the SQL required to migrate the old schema to the new schema is +// described below: +// +// A schema follows a hierarchy: Schemas -> Tables -> Columns and Indexes +// Every level of the hierarchy can depend on other items at the same level (indexes depend on columns). +// A similar idea applies with constraints, including Foreign key constraints. Because constraints can have cross-table +// dependencies, they can be viewed at the same level as tables. This hierarchy becomes interwoven with partitions +// +// Diffing two sets of schema objects follows a common pattern: +// (DIFFING) +// 1. Diff two lists of schema objects (e.g., schemas, tables). An item is new if it's name is not present in the old list. +// An item is deleted if it's name is not present in the new list. Otherwise, an item might have been altered +// 2. For each potentially altered item, generate the diff between the old and new. This might involve diffing lists if they have +// nested items (recursing into step 1) +// (GENERATING/RESOLVING) +// 3. Generate the SQL required for the deleted items (ADDS), the new items, and the altered items. +// These items might have interwoven dependencies +// 4. Topologically sort the diffed items +// +// The diffing is handled by diffLists. Diffs lists takes two lists of schema objects identifies which are +// added, deleted, or potentially altered. If the items are potentially altered, it will pass the items +// to a callback which handles diffing the old and new versions. This callback might call into diff lists +// for items nested inside its hierarchy +// +// Generating the SQL for the resulting diff from the two lists of items is handled by the SQL(Vertex)Generators. +// Every schema object defines a SQL(Vertex)Generator. A SQL(Vertex)Generator generates the SQL required to add, delete, +// or alter a schema object. If altering a schema object, the SQL(Vertex)Generator is passed the diff generated by the callback in diffLists. +// The sqlGenerator just generates SQL, while the sqlVertexGenerator also defines dependencies that a schema object has +// on other schema objects + +func buildSchemaDiff(old, new schema.Schema) (schemaDiff, bool, error) { + tableDiffs, err := diffLists(old.Tables, new.Tables, buildTableDiff) + if err != nil { + return schemaDiff{}, false, fmt.Errorf("diffing tables: %w", err) + } + + newSchemaTablesByName := buildSchemaObjMap(new.Tables) + addedTablesByName := buildSchemaObjMap(tableDiffs.adds) + indexesDiff, err := diffLists(old.Indexes, new.Indexes, func(old, new schema.Index, oldIndex, newIndex int) (indexDiff, bool, error) { + return buildIndexDiff(newSchemaTablesByName, addedTablesByName, old, new, oldIndex, newIndex) + }) + if err != nil { + return schemaDiff{}, false, fmt.Errorf("diffing indexes: %w", err) + } + + functionDiffs, err := diffLists(old.Functions, new.Functions, func(old, new schema.Function, _, _ int) (functionDiff, bool, error) { + return functionDiff{ + oldAndNew[schema.Function]{ + old: old, + new: new, + }, + }, false, nil + }) + if err != nil { + return schemaDiff{}, false, fmt.Errorf("diffing functions: %w", err) + } + + triggerDiffs, err := diffLists(old.Triggers, new.Triggers, func(old, new schema.Trigger, _, _ int) (triggerDiff, bool, error) { + if _, isOnNewTable := addedTablesByName[new.OwningTableUnescapedName]; isOnNewTable { + // If the table is new, then it must be re-created (this occurs if the base table has been + // re-created). In other words, a trigger must be re-created if the owning table is re-created + return triggerDiff{}, true, nil + } + return triggerDiff{ + oldAndNew[schema.Trigger]{ + old: old, + new: new, + }, + }, false, nil + }) + if err != nil { + return schemaDiff{}, false, fmt.Errorf("diffing triggers: %w", err) + } + + return schemaDiff{ + oldAndNew: oldAndNew[schema.Schema]{ + old: old, + new: new, + }, + tableDiffs: tableDiffs, + indexDiffs: indexesDiff, + functionDiffs: functionDiffs, + triggerDiffs: triggerDiffs, + }, false, nil +} + +func buildTableDiff(oldTable, newTable schema.Table, _, _ int) (diff tableDiff, requiresRecreation bool, err error) { + if oldTable.IsPartitioned() != newTable.IsPartitioned() { + return tableDiff{}, true, nil + } else if oldTable.PartitionKeyDef != newTable.PartitionKeyDef { + // We won't support changing partition key def due to issues with requiresRecreation. + // + // BLUF of the problem: If you have a flattened hierarchy (partitions, materialized views) and the parent + // is re-created but the children are unchanged, the children need to be re-created. + // + // If we want to add support, then we need diffLists to identify if a parent has been re-created (or if parents have changed), + // so it knows to re-create the child. This problem becomes more acute when a child can belong to + // multiple parents, e.g., materialized views. Ultimately, it's a graph problem in diffLists that can + // be solved through a `getParents` function + // + // Until the above is implemented, we can't support requiresRecreation on any flattened hierarchies + return tableDiff{}, false, fmt.Errorf("changing partition key def: %w", ErrNotImplemented) + } + + if oldTable.ParentTableName != newTable.ParentTableName { + // Since diffLists doesn't handle re-creating hierarchies that change, we need to manually + // identify if the hierarchy has changed. This approach will NOT work if we support multiple layers + // of partitioning because it's possible the parent's parent changed but the parent remained the same + return tableDiff{}, true, nil + } + + columnsDiff, err := diffLists( + oldTable.Columns, + newTable.Columns, + func(old, new schema.Column, oldIndex, newIndex int) (columnDiff, bool, error) { + return columnDiff{ + oldAndNew: oldAndNew[schema.Column]{old: old, new: new}, + oldOrdering: oldIndex, + newOrdering: newIndex, + }, false, nil + }, + ) + if err != nil { + return tableDiff{}, false, fmt.Errorf("diffing columns: %w", err) + } + + checkConsDiff, err := diffLists( + oldTable.CheckConstraints, + newTable.CheckConstraints, + func(old, new schema.CheckConstraint, _, _ int) (checkConstraintDiff, bool, error) { + recreateConstraint := (old.Expression != new.Expression) || + (old.IsValid && !new.IsValid) || + (old.IsInheritable != new.IsInheritable) + return checkConstraintDiff{oldAndNew[schema.CheckConstraint]{old: old, new: new}}, + recreateConstraint, + nil + }, + ) + if err != nil { + return tableDiff{}, false, fmt.Errorf("diffing lists: %w", err) + } + + return tableDiff{ + oldAndNew: oldAndNew[schema.Table]{ + old: oldTable, + new: newTable, + }, + columnsDiff: columnsDiff, + checkConstraintDiff: checkConsDiff, + }, false, nil +} + +// buildIndexDiff builds the index diff +func buildIndexDiff(newSchemaTablesByName map[string]schema.Table, addedTablesByName map[string]schema.Table, old, new schema.Index, _, _ int) (diff indexDiff, requiresRecreation bool, err error) { + updatedOld := old + + if _, isOnNewTable := addedTablesByName[new.TableName]; isOnNewTable { + // If the table is new, then it must be re-created (this occurs if the base table has been + // re-created). In other words, an index must be re-created if the owning table is re-created + return indexDiff{}, true, nil + } + + if len(old.ParentIdxName) == 0 { + // If the old index didn't belong to a partitioned index (and the new index does), we can resolve the parent + // index name diff if the index now belongs to a partitioned index by attaching the index. + // We can't switch an index partition from one parent to another; in that instance, we must + // re-create the index + updatedOld.ParentIdxName = new.ParentIdxName + } + + if !new.IsPartitionOfIndex() && !old.IsPk && new.IsPk { + // If the old index is not part of a primary key and the new index is part of a primary key, + // the constraint name diff is resolvable by adding the index to the primary key. + // Partitioned indexes are the exception; for partitioned indexes that are + // primary keys, the indexes are created with the constraint on the base table and cannot + // be attached to the base index + // In the future, we can change this behavior to ONLY create the constraint on the base table + // and follow a similar paradigm to adding indexes + updatedOld.ConstraintName = new.ConstraintName + updatedOld.IsPk = new.IsPk + } + + if isOnPartitionedTable, err := isOnPartitionedTable(newSchemaTablesByName, new); err != nil { + return indexDiff{}, false, err + } else if isOnPartitionedTable && old.IsInvalid && !new.IsInvalid { + // If the index is a partitioned index, it can be made valid automatically by attaching the index partitions + // We don't need to re-create it. + updatedOld.IsInvalid = new.IsInvalid + } + + recreateIndex := !cmp.Equal(updatedOld, new) + return indexDiff{ + oldAndNew: oldAndNew[schema.Index]{ + old: old, new: new, + }, + }, recreateIndex, nil +} + +type schemaSQLGenerator struct{} + +func (schemaSQLGenerator) Alter(diff schemaDiff) ([]Statement, error) { + tablesInNewSchemaByName := buildSchemaObjMap(diff.new.Tables) + deletedTablesByName := buildSchemaObjMap(diff.tableDiffs.deletes) + + tableSQLVertexGenerator := tableSQLVertexGenerator{ + deletedTablesByName: deletedTablesByName, + tablesInNewSchemaByName: tablesInNewSchemaByName, + } + tableGraphs, err := diff.tableDiffs.resolveToSQLGraph(&tableSQLVertexGenerator) + if err != nil { + return nil, fmt.Errorf("resolving table sql graphs: %w", err) + } + + indexesInNewSchemaByTableName := make(map[string][]schema.Index) + for _, idx := range diff.new.Indexes { + indexesInNewSchemaByTableName[idx.TableName] = append(indexesInNewSchemaByTableName[idx.TableName], idx) + } + attachPartitionSQLVertexGenerator := attachPartitionSQLVertexGenerator{ + indexesInNewSchemaByTableName: indexesInNewSchemaByTableName, + } + attachPartitionGraphs, err := diff.tableDiffs.resolveToSQLGraph(&attachPartitionSQLVertexGenerator) + if err != nil { + return nil, fmt.Errorf("resolving attach partition sql graphs: %w", err) + } + + renameConflictingIndexSQLVertexGenerator := newRenameConflictingIndexSQLVertexGenerator(buildSchemaObjMap(diff.old.Indexes)) + renameConflictingIndexGraphs, err := diff.indexDiffs.resolveToSQLGraph(&renameConflictingIndexSQLVertexGenerator) + if err != nil { + return nil, fmt.Errorf("resolving renaming conflicting indexes: %w", err) + } + + indexSQLVertexGenerator := indexSQLVertexGenerator{ + deletedTablesByName: deletedTablesByName, + addedTablesByName: buildSchemaObjMap(diff.tableDiffs.adds), + tablesInNewSchemaByName: tablesInNewSchemaByName, + indexesInNewSchemaByName: buildSchemaObjMap(diff.new.Indexes), + indexRenamesByOldName: renameConflictingIndexSQLVertexGenerator.getRenames(), + } + indexGraphs, err := diff.indexDiffs.resolveToSQLGraph(&indexSQLVertexGenerator) + if err != nil { + return nil, fmt.Errorf("resolving index sql graphs: %w", err) + } + + functionsInNewSchemaByName := buildSchemaObjMap(diff.new.Functions) + + functionSQLVertexGenerator := functionSQLVertexGenerator{ + functionsInNewSchemaByName: functionsInNewSchemaByName, + } + functionGraphs, err := diff.functionDiffs.resolveToSQLGraph(&functionSQLVertexGenerator) + if err != nil { + return nil, fmt.Errorf("resolving function sql graphs: %w", err) + } + + triggerSQLVertexGenerator := triggerSQLVertexGenerator{ + functionsInNewSchemaByName: functionsInNewSchemaByName, + } + triggerGraphs, err := diff.triggerDiffs.resolveToSQLGraph(&triggerSQLVertexGenerator) + if err != nil { + return nil, fmt.Errorf("resolving trigger sql graphs: %w", err) + } + + if err := tableGraphs.union(attachPartitionGraphs); err != nil { + return nil, fmt.Errorf("unioning table and attach partition graphs: %w", err) + } + if err := tableGraphs.union(indexGraphs); err != nil { + return nil, fmt.Errorf("unioning table and index graphs: %w", err) + } + if err := tableGraphs.union(renameConflictingIndexGraphs); err != nil { + return nil, fmt.Errorf("unioning table and rename conflicting index graphs: %w", err) + } + if err := tableGraphs.union(functionGraphs); err != nil { + return nil, fmt.Errorf("unioning table and function graphs: %w", err) + } + if err := tableGraphs.union(triggerGraphs); err != nil { + return nil, fmt.Errorf("unioning table and trigger graphs: %w", err) + } + + return tableGraphs.toOrderedStatements() +} + +func buildSchemaObjMap[S schema.Object](s []S) map[string]S { + output := make(map[string]S) + for _, obj := range s { + output[obj.GetName()] = obj + } + return output +} + +type tableSQLVertexGenerator struct { + deletedTablesByName map[string]schema.Table + tablesInNewSchemaByName map[string]schema.Table +} + +var _ sqlVertexGenerator[schema.Table, tableDiff] = &tableSQLVertexGenerator{} + +func (t *tableSQLVertexGenerator) Add(table schema.Table) ([]Statement, error) { + if table.IsPartition() { + if table.IsPartitioned() { + return nil, fmt.Errorf("partitioned partitions: %w", ErrNotImplemented) + } + if len(table.CheckConstraints) > 0 { + return nil, fmt.Errorf("check constraints on partitions: %w", ErrNotImplemented) + } + // We attach the partitions separately. So the partition must have all the same check constraints + // as the original table + table.CheckConstraints = append(table.CheckConstraints, t.tablesInNewSchemaByName[table.ParentTableName].CheckConstraints...) + } + + var stmts []Statement + + var columnDefs []string + for _, column := range table.Columns { + columnDefs = append(columnDefs, "\t"+buildColumnDefinition(column)) + } + createTableSb := strings.Builder{} + createTableSb.WriteString(fmt.Sprintf("CREATE TABLE %s (\n%s\n)", + schema.EscapeIdentifier(table.Name), + strings.Join(columnDefs, ",\n"), + )) + if table.IsPartitioned() { + createTableSb.WriteString(fmt.Sprintf("PARTITION BY %s", table.PartitionKeyDef)) + } + stmts = append(stmts, Statement{ + DDL: createTableSb.String(), + Timeout: statementTimeoutDefault, + }) + + csg := checkConstraintSQLGenerator{tableName: table.Name} + for _, checkCon := range table.CheckConstraints { + addConStmts, err := csg.Add(checkCon) + if err != nil { + return nil, fmt.Errorf("generating add check constraint statements for check constraint %s: %w", checkCon.Name, err) + } + // Remove hazards from statements since the table is brand new + stmts = append(stmts, stripMigrationHazards(addConStmts)...) + } + + return stmts, nil +} + +func (t *tableSQLVertexGenerator) Delete(table schema.Table) ([]Statement, error) { + if table.IsPartition() { + // Don't support dropping partitions without dropping the base table. This would be easy to implement, but we + // would need to add tests for it. + // + // The base table might be recreated, so check if its deleted rather than just checking if it does not exist in + // the new schema + if _, baseTableDropped := t.deletedTablesByName[table.ParentTableName]; !baseTableDropped { + return nil, fmt.Errorf("deleting partitions without dropping parent table: %w", ErrNotImplemented) + } + // It will be dropped when the parent table is dropped + return nil, nil + } + return []Statement{ + { + DDL: fmt.Sprintf("DROP TABLE %s", schema.EscapeIdentifier(table.Name)), + Timeout: statementTimeoutTableDrop, + Hazards: []MigrationHazard{{ + Type: MigrationHazardTypeDeletesData, + Message: "Deletes all rows in the table (and the table itself)", + }}, + }, + }, nil +} + +func (t *tableSQLVertexGenerator) Alter(diff tableDiff) ([]Statement, error) { + if diff.old.IsPartition() != diff.new.IsPartition() { + return nil, fmt.Errorf("changing a partition to no longer be a partition (or vice versa): %w", ErrNotImplemented) + } else if diff.new.IsPartition() { + return t.alterPartition(diff) + } + + if diff.old.PartitionKeyDef != diff.new.PartitionKeyDef { + return nil, fmt.Errorf("changing partition key def: %w", ErrNotImplemented) + } + + columnSQLGenerator := columnSQLGenerator{tableName: diff.new.Name} + columnGeneratedSQL, err := diff.columnsDiff.resolveToSQLGroupedByEffect(&columnSQLGenerator) + if err != nil { + return nil, fmt.Errorf("resolving index diff: %w", err) + } + + checkConSQLGenerator := checkConstraintSQLGenerator{tableName: diff.new.Name} + checkConGeneratedSQL, err := diff.checkConstraintDiff.resolveToSQLGroupedByEffect(&checkConSQLGenerator) + if err != nil { + return nil, fmt.Errorf("Resolving check constraints diff: %w", err) + } + + var stmts []Statement + stmts = append(stmts, checkConGeneratedSQL.Deletes...) + stmts = append(stmts, columnGeneratedSQL.Deletes...) + stmts = append(stmts, columnGeneratedSQL.Adds...) + stmts = append(stmts, checkConGeneratedSQL.Adds...) + stmts = append(stmts, columnGeneratedSQL.Alters...) + stmts = append(stmts, checkConGeneratedSQL.Alters...) + return stmts, nil +} + +func (t *tableSQLVertexGenerator) alterPartition(diff tableDiff) ([]Statement, error) { + if diff.old.ForValues != diff.new.ForValues { + return nil, fmt.Errorf("altering partition FOR VALUES: %w", ErrNotImplemented) + } + if !diff.checkConstraintDiff.isEmpty() { + return nil, fmt.Errorf("check constraints on partitions: %w", ErrNotImplemented) + } + + var stmts []Statement + // ColumnsDiff should only have nullability changes. Partitioned tables + // aren't concerned about old/new columns added + for _, colDiff := range diff.columnsDiff.alters { + if colDiff.old.IsNullable == colDiff.new.IsNullable { + continue + } + alterColumnPrefix := fmt.Sprintf("%s ALTER COLUMN %s", alterTablePrefix(diff.new.Name), schema.EscapeIdentifier(colDiff.new.Name)) + if colDiff.new.IsNullable { + stmts = append(stmts, Statement{ + DDL: fmt.Sprintf("%s DROP NOT NULL", alterColumnPrefix), + Timeout: statementTimeoutDefault, + Hazards: []MigrationHazard{ + { + Type: MigrationHazardTypeAcquiresAccessExclusiveLock, + Message: "Marking a column as not null requires a full table scan, which will lock out " + + "writes on the partition", + }, + }, + }) + } else { + stmts = append(stmts, Statement{ + DDL: fmt.Sprintf("%s SET NOT NULL", alterColumnPrefix), + Timeout: statementTimeoutDefault, + }) + } + } + + return stmts, nil +} + +func (t *tableSQLVertexGenerator) GetSQLVertexId(table schema.Table) string { + return buildTableVertexId(table.Name) +} + +func (t *tableSQLVertexGenerator) GetAddAlterDependencies(table, _ schema.Table) []dependency { + deps := []dependency{ + mustRun(t.GetSQLVertexId(table), diffTypeAddAlter).after(t.GetSQLVertexId(table), diffTypeDelete), + } + + if table.IsPartition() { + deps = append(deps, + mustRun(t.GetSQLVertexId(table), diffTypeAddAlter).after(buildTableVertexId(table.ParentTableName), diffTypeAddAlter), + ) + } + return deps +} + +func (t *tableSQLVertexGenerator) GetDeleteDependencies(table schema.Table) []dependency { + var deps []dependency + if table.IsPartition() { + deps = append(deps, + mustRun(t.GetSQLVertexId(table), diffTypeDelete).after(buildTableVertexId(table.ParentTableName), diffTypeDelete), + ) + } + return deps +} + +type columnSQLGenerator struct { + tableName string +} + +func (csg *columnSQLGenerator) Add(column schema.Column) ([]Statement, error) { + return []Statement{{ + DDL: fmt.Sprintf("%s ADD COLUMN %s", alterTablePrefix(csg.tableName), buildColumnDefinition(column)), + Timeout: statementTimeoutDefault, + }}, nil +} + +func (csg *columnSQLGenerator) Delete(column schema.Column) ([]Statement, error) { + return []Statement{{ + DDL: fmt.Sprintf("%s DROP COLUMN %s", alterTablePrefix(csg.tableName), schema.EscapeIdentifier(column.Name)), + Timeout: statementTimeoutDefault, + Hazards: []MigrationHazard{ + { + Type: MigrationHazardTypeDeletesData, + Message: "Deletes all values in the column", + }, + }, + }}, nil +} + +func (csg *columnSQLGenerator) Alter(diff columnDiff) ([]Statement, error) { + if diff.oldOrdering != diff.newOrdering { + return nil, fmt.Errorf("old=%d; new=%d: %w", diff.oldOrdering, diff.newOrdering, ErrColumnOrderingChanged) + } + oldColumn, newColumn := diff.old, diff.new + var stmts []Statement + alterColumnPrefix := fmt.Sprintf("%s ALTER COLUMN %s", alterTablePrefix(csg.tableName), schema.EscapeIdentifier(newColumn.Name)) + + if oldColumn.IsNullable != newColumn.IsNullable { + if newColumn.IsNullable { + stmts = append(stmts, Statement{ + DDL: fmt.Sprintf("%s DROP NOT NULL", alterColumnPrefix), + Timeout: statementTimeoutDefault, + }) + } else { + stmts = append(stmts, Statement{ + DDL: fmt.Sprintf("%s SET NOT NULL", alterColumnPrefix), + Timeout: statementTimeoutDefault, + Hazards: []MigrationHazard{ + { + Type: MigrationHazardTypeAcquiresAccessExclusiveLock, + Message: "Marking a column as not null requires a full table scan, which will lock out writes", + }, + }, + }) + } + } + + if len(oldColumn.Default) > 0 && len(newColumn.Default) == 0 { + // Drop the default before type conversion. This will allow type conversions + // between incompatible types if the previous column has a default and the new column is dropping its default + stmts = append(stmts, Statement{ + DDL: fmt.Sprintf("%s DROP DEFAULT", alterColumnPrefix), + Timeout: statementTimeoutDefault, + }) + } + + if !strings.EqualFold(oldColumn.Type, newColumn.Type) || + !strings.EqualFold(oldColumn.Collation.GetFQEscapedName(), newColumn.Collation.GetFQEscapedName()) { + stmts = append(stmts, + []Statement{ + csg.generateTypeTransformationStatement( + alterColumnPrefix, + schema.EscapeIdentifier(newColumn.Name), + oldColumn.Type, + newColumn.Type, + newColumn.Collation, + ), + // When "SET TYPE" is used to alter a column, that column's statistics are removed, which could + // affect query plans. In order to mitigate the effect on queries, re-generate the statistics for the + // column before continuing with the migration. + { + DDL: fmt.Sprintf("ANALYZE %s (%s)", schema.EscapeIdentifier(csg.tableName), schema.EscapeIdentifier(newColumn.Name)), + Timeout: statementTimeoutAnalyzeColumn, + Hazards: []MigrationHazard{ + { + Type: MigrationHazardTypeImpactsDatabasePerformance, + Message: "Running analyze will read rows from the table, putting increased load " + + "on the database and consuming database resources. It won't prevent reads/writes to " + + "the table, but it could affect performance when executing queries.", + }, + }, + }, + }...) + } + + if oldColumn.Default != newColumn.Default && len(newColumn.Default) > 0 { + // Set the default after the type conversion. This will allow type conversions + // between incompatible types if the previous column has no default and the new column has a default + stmts = append(stmts, Statement{ + DDL: fmt.Sprintf("%s SET DEFAULT %s", alterColumnPrefix, newColumn.Default), + Timeout: statementTimeoutDefault, + }) + } + + return stmts, nil +} + +func (csg *columnSQLGenerator) generateTypeTransformationStatement( + prefix string, + name string, + oldType string, + newType string, + newTypeCollation schema.SchemaQualifiedName, +) Statement { + if strings.EqualFold(oldType, "bigint") && + strings.EqualFold(newType, "timestamp without time zone") { + return Statement{ + DDL: fmt.Sprintf("%s SET DATA TYPE %s using to_timestamp(%s / 1000)", + prefix, + newType, + name, + ), + Timeout: statementTimeoutDefault, + Hazards: []MigrationHazard{{ + Type: MigrationHazardTypeAcquiresAccessExclusiveLock, + Message: "This will completely lock the table while the data is being " + + "re-written for a duration of time that scales with the size of your data. " + + "The values previously stored as BIGINT will be translated into a " + + "TIMESTAMP value via the PostgreSQL to_timestamp() function. This " + + "translation will assume that the values stored in BIGINT represent a " + + "millisecond epoch value.", + }}, + } + } + + collationModifier := "" + if !newTypeCollation.IsEmpty() { + collationModifier = fmt.Sprintf("COLLATE %s ", newTypeCollation.GetFQEscapedName()) + } + + return Statement{ + DDL: fmt.Sprintf("%s SET DATA TYPE %s %susing %s::%s", + prefix, + newType, + collationModifier, + name, + newType, + ), + Timeout: statementTimeoutDefault, + Hazards: []MigrationHazard{{ + Type: MigrationHazardTypeAcquiresAccessExclusiveLock, + Message: "This will completely lock the table while the data is being re-written. " + + "The duration of this conversion depends on if the type conversion is trivial " + + "or not. A non-trivial conversion will require a table rewrite. A trivial " + + "conversion is one where the binary values are coercible and the column " + + "contents are not changing.", + }}, + } +} + +type renameConflictingIndexSQLVertexGenerator struct { + // indexesInOldSchemaByName is a map of index name to the index in the old schema + // It is used to identify if an index has been re-created + oldSchemaIndexesByName map[string]schema.Index + + indexRenamesByOldName map[string]string +} + +func newRenameConflictingIndexSQLVertexGenerator(oldSchemaIndexesByName map[string]schema.Index) renameConflictingIndexSQLVertexGenerator { + return renameConflictingIndexSQLVertexGenerator{ + oldSchemaIndexesByName: oldSchemaIndexesByName, + indexRenamesByOldName: make(map[string]string), + } +} + +func (rsg *renameConflictingIndexSQLVertexGenerator) Add(index schema.Index) ([]Statement, error) { + if oldIndex, indexIsBeingRecreated := rsg.oldSchemaIndexesByName[index.Name]; !indexIsBeingRecreated { + return nil, nil + } else if oldIndex.IsPk && index.IsPk { + // Don't bother renaming if both are primary keys, since the new index will need to be created after the old + // index because we can't have two primary keys at the same time. + // + // To make changing primary keys index-gap free (mostly online), we could build the underlying new primary key index, + // drop the old primary constraint (and index), and then add the primary key constraint using the new index. + // This would require us to split primary key constraint SQL generation from index SQL generation + return nil, nil + } + + newName, err := rsg.generateNonConflictingName(index) + if err != nil { + return nil, fmt.Errorf("generating non-conflicting name: %w", err) + } + + rsg.indexRenamesByOldName[index.Name] = newName + + return []Statement{{ + DDL: fmt.Sprintf("ALTER INDEX %s RENAME TO %s", schema.EscapeIdentifier(index.Name), schema.EscapeIdentifier(newName)), + Timeout: statementTimeoutDefault, + }}, nil +} + +func (rsg *renameConflictingIndexSQLVertexGenerator) generateNonConflictingName(index schema.Index) (string, error) { + uuid, err := uuid.NewRandom() + if err != nil { + return "", fmt.Errorf("generating UUID: %w", err) + } + + newNameSuffix := fmt.Sprintf("_%s", uuid.String()) + idxNameTruncationIdx := len(index.Name) + if len(index.Name) > maxPostgresIdentifierSize-len(newNameSuffix) { + idxNameTruncationIdx = maxPostgresIdentifierSize - len(newNameSuffix) + } + + return index.Name[:idxNameTruncationIdx] + newNameSuffix, nil +} + +func (rsg *renameConflictingIndexSQLVertexGenerator) getRenames() map[string]string { + return rsg.indexRenamesByOldName +} + +func (rsg *renameConflictingIndexSQLVertexGenerator) Delete(_ schema.Index) ([]Statement, error) { + return nil, nil +} + +func (rsg *renameConflictingIndexSQLVertexGenerator) Alter(_ indexDiff) ([]Statement, error) { + return nil, nil +} + +func (*renameConflictingIndexSQLVertexGenerator) GetSQLVertexId(index schema.Index) string { + return buildRenameConflictingIndexVertexId(index.Name) +} + +func (rsg *renameConflictingIndexSQLVertexGenerator) GetAddAlterDependencies(_, _ schema.Index) []dependency { + return nil +} + +func (rsg *renameConflictingIndexSQLVertexGenerator) GetDeleteDependencies(_ schema.Index) []dependency { + return nil +} + +func buildRenameConflictingIndexVertexId(indexName string) string { + return buildVertexId("indexrename", indexName) +} + +type indexSQLVertexGenerator struct { + // deletedTablesByName is a map of table name to the deleted tables (and partitions) + deletedTablesByName map[string]schema.Table + // addedTablesByName is a map of table name to the new tables (and partitions) + // This is used to identify if hazards are necessary + addedTablesByName map[string]schema.Table + // tablesInNewSchemaByName is a map of table name to tables (and partitions) in the new schema. + // These tables are not necessarily new. This is used to identify if the table is partitioned + tablesInNewSchemaByName map[string]schema.Table + // indexesInNewSchemaByName is a map of index name to the index + // This is used to identify the parent index is a primary key + indexesInNewSchemaByName map[string]schema.Index + // indexRenamesByOldName is a map of any renames performed by the conflicting index sql vertex generator + indexRenamesByOldName map[string]string +} + +func (isg *indexSQLVertexGenerator) Add(index schema.Index) ([]Statement, error) { + stmts, err := isg.addIdxStmtsWithHazards(index) + if err != nil { + return stmts, err + } + + if _, isNewTable := isg.addedTablesByName[index.TableName]; isNewTable { + stmts = stripMigrationHazards(stmts) + } + return stmts, nil +} + +func (isg *indexSQLVertexGenerator) addIdxStmtsWithHazards(index schema.Index) ([]Statement, error) { + if index.IsInvalid { + return nil, fmt.Errorf("can't create an invalid index: %w", ErrNotImplemented) + } + + if index.IsPk { + if index.IsPartitionOfIndex() { + if parentIdx, ok := isg.indexesInNewSchemaByName[index.ParentIdxName]; !ok { + return nil, fmt.Errorf("could not find parent index %s", index.ParentIdxName) + } else if parentIdx.IsPk { + // All indexes associated with parent primary keys are automatically created by their parent + return nil, nil + } + } + + if isOnPartitionedTable, err := isg.isOnPartitionedTable(index); err != nil { + return nil, err + } else if isOnPartitionedTable { + // A partitioned table can't have a constraint added to it with "USING INDEX", so just make the index + // automatically through the constraint. This currently blocks and is a dangerous operation + // If users want to be able to switch primary keys concurrently, support can be added in the future, + // with a similar strategy to adding indexes to a partitioned table + return []Statement{ + { + DDL: fmt.Sprintf("%s ADD CONSTRAINT %s PRIMARY KEY (%s)", + alterTablePrefix(index.TableName), + schema.EscapeIdentifier(index.Name), + strings.Join(formattedNamesForSQL(index.Columns), ", "), + ), + Timeout: statementTimeoutDefault, + Hazards: []MigrationHazard{ + { + Type: MigrationHazardTypeAcquiresShareLock, + Message: "This will lock writes to the table while the index build occurs.", + }, + { + Type: MigrationHazardTypeIndexBuild, + Message: "This is non-concurrent because adding PK's concurrently hasn't been" + + "implemented yet. It WILL lock out writes. Index builds require a non-trivial " + + "amount of CPU as well, which might affect database performance", + }, + }, + }, + }, nil + } + } + + var stmts []Statement + var createIdxStmtHazards []MigrationHazard + + createIdxStmt := string(index.GetIndexDefStmt) + createIdxStmtTimeout := statementTimeoutDefault + if isOnPartitionedTable, err := isg.isOnPartitionedTable(index); err != nil { + return nil, err + } else if !isOnPartitionedTable { + // Only indexes on non-partitioned tables can be created concurrently + concurrentCreateIdxStmt, err := index.GetIndexDefStmt.ToCreateIndexConcurrently() + if err != nil { + return nil, fmt.Errorf("modifying index def statement to concurrently: %w", err) + } + createIdxStmt = concurrentCreateIdxStmt + createIdxStmtHazards = append(createIdxStmtHazards, MigrationHazard{ + Type: MigrationHazardTypeIndexBuild, + Message: "This might affect database performance. " + + "Concurrent index builds require a non-trivial amount of CPU, potentially affecting database performance. " + + "They also can take a while but do not lock out writes.", + }) + createIdxStmtTimeout = statementTimeoutConcurrentIndexBuild + } + + stmts = append(stmts, Statement{ + DDL: createIdxStmt, + Timeout: createIdxStmtTimeout, + Hazards: createIdxStmtHazards, + }) + + _, isNewTable := isg.addedTablesByName[index.TableName] + if index.IsPartitionOfIndex() && !isNewTable { + // Exclude if the partition is new because the index will be attached when the partition is attached + stmts = append(stmts, buildAttachIndex(index)) + } + + if index.IsPk { + stmts = append(stmts, isg.addPkConstraintUsingIdx(index)) + } else if len(index.ConstraintName) > 0 { + return nil, fmt.Errorf("constraints not supported for non-primary key indexes: %w", ErrNotImplemented) + } + return stmts, nil +} + +func (isg *indexSQLVertexGenerator) Delete(index schema.Index) ([]Statement, error) { + _, tableWasDeleted := isg.deletedTablesByName[index.TableName] + // An index will be dropped if its owning table is dropped. + // Similarly, a partition of an index will be dropped when the parent index is dropped + if tableWasDeleted || index.IsPartitionOfIndex() { + return nil, nil + } + + // An index used by a primary key constraint/unique constraint cannot be dropped concurrently + if len(index.ConstraintName) > 0 { + // The index has been potentially renamed, which causes the constraint to be renamed. Use the updated name + constraintName := index.ConstraintName + if rename, hasRename := isg.indexRenamesByOldName[index.Name]; hasRename { + constraintName = rename + } + + // Dropping the constraint will automatically drop the index. There is no way to drop + // the constraint without dropping the index + return []Statement{ + { + DDL: dropConstraintDDL(index.TableName, constraintName), + Timeout: statementTimeoutDefault, + Hazards: []MigrationHazard{ + migrationHazardIndexDroppedAcquiresLock, + migrationHazardIndexDroppedQueryPerf, + }, + }, + }, nil + } + + var dropIndexStmtHazards []MigrationHazard + concurrentlyModifier := "CONCURRENTLY " + dropIndexStmtTimeout := statementTimeoutConcurrentIndexDrop + if isOnPartitionedTable, err := isg.isOnPartitionedTable(index); err != nil { + return nil, err + } else if isOnPartitionedTable { + // Currently, postgres has no good way of dropping an index partition concurrently + concurrentlyModifier = "" + dropIndexStmtTimeout = statementTimeoutDefault + // Technically, CONCURRENTLY also locks the table, but it waits for an "opportunity" to lock + // We will omit the locking hazard of concurrent drops for now + dropIndexStmtHazards = append(dropIndexStmtHazards, migrationHazardIndexDroppedAcquiresLock) + } + + // The index has been potentially renamed. Use the updated name + indexName := index.Name + if rename, hasRename := isg.indexRenamesByOldName[index.Name]; hasRename { + indexName = rename + } + + return []Statement{{ + DDL: fmt.Sprintf("DROP INDEX %s%s", concurrentlyModifier, schema.EscapeIdentifier(indexName)), + Timeout: dropIndexStmtTimeout, + Hazards: append(dropIndexStmtHazards, migrationHazardIndexDroppedQueryPerf), + }}, nil +} + +func (isg *indexSQLVertexGenerator) Alter(diff indexDiff) ([]Statement, error) { + var stmts []Statement + + if isOnPartitionedTable, err := isg.isOnPartitionedTable(diff.new); err != nil { + return nil, err + } else if isOnPartitionedTable && diff.old.IsInvalid && !diff.new.IsInvalid { + // If the index is a partitioned index, it can be made valid automatically by attaching the index partitions + diff.old.IsInvalid = diff.new.IsInvalid + } + + if !diff.new.IsPartitionOfIndex() && !diff.old.IsPk && diff.new.IsPk { + stmts = append(stmts, isg.addPkConstraintUsingIdx(diff.new)) + diff.old.IsPk = diff.new.IsPk + diff.old.ConstraintName = diff.new.ConstraintName + } + + if len(diff.old.ParentIdxName) == 0 && len(diff.new.ParentIdxName) > 0 { + stmts = append(stmts, buildAttachIndex(diff.new)) + diff.old.ParentIdxName = diff.new.ParentIdxName + } + + if !cmp.Equal(diff.old, diff.new) { + return nil, fmt.Errorf("index diff could not be resolved %s", cmp.Diff(diff.old, diff.new)) + } + + return stmts, nil +} + +func (isg *indexSQLVertexGenerator) isOnPartitionedTable(index schema.Index) (bool, error) { + return isOnPartitionedTable(isg.tablesInNewSchemaByName, index) +} + +// Returns true if the table the index belongs too is partitioned. If the table is a partition of a +// partitioned table, this will always return false +func isOnPartitionedTable(tablesInNewSchemaByName map[string]schema.Table, index schema.Index) (bool, error) { + if owningTable, ok := tablesInNewSchemaByName[index.TableName]; !ok { + return false, fmt.Errorf("could not find table in new schema with name %s", index.TableName) + } else { + return owningTable.IsPartitioned(), nil + } +} + +func (isg *indexSQLVertexGenerator) addPkConstraintUsingIdx(index schema.Index) Statement { + return Statement{ + DDL: fmt.Sprintf("%s ADD CONSTRAINT %s PRIMARY KEY USING INDEX %s", alterTablePrefix(index.TableName), schema.EscapeIdentifier(index.ConstraintName), schema.EscapeIdentifier(index.Name)), + Timeout: statementTimeoutDefault, + } +} + +func buildAttachIndex(index schema.Index) Statement { + return Statement{ + DDL: fmt.Sprintf("ALTER INDEX %s ATTACH PARTITION %s", schema.EscapeIdentifier(index.ParentIdxName), schema.EscapeIdentifier(index.Name)), + Timeout: statementTimeoutDefault, + } +} + +func (*indexSQLVertexGenerator) GetSQLVertexId(index schema.Index) string { + return buildIndexVertexId(index.Name) +} + +func (isg *indexSQLVertexGenerator) GetAddAlterDependencies(index, _ schema.Index) []dependency { + dependencies := []dependency{ + mustRun(isg.GetSQLVertexId(index), diffTypeAddAlter).after(buildTableVertexId(index.TableName), diffTypeAddAlter), + // To allow for online changes to indexes, rename the older version of the index (if it exists) before the new version is added + mustRun(isg.GetSQLVertexId(index), diffTypeAddAlter).after(buildRenameConflictingIndexVertexId(index.Name), diffTypeAddAlter), + } + + if index.IsPartitionOfIndex() { + // Partitions of indexes must be created after the parent index is created + dependencies = append(dependencies, + mustRun(isg.GetSQLVertexId(index), diffTypeAddAlter).after(buildIndexVertexId(index.ParentIdxName), diffTypeAddAlter)) + } + + return dependencies +} + +func (isg *indexSQLVertexGenerator) GetDeleteDependencies(index schema.Index) []dependency { + dependencies := []dependency{ + mustRun(isg.GetSQLVertexId(index), diffTypeDelete).after(buildTableVertexId(index.TableName), diffTypeDelete), + // Drop the index after it has been potentially renamed + mustRun(isg.GetSQLVertexId(index), diffTypeDelete).after(buildRenameConflictingIndexVertexId(index.Name), diffTypeAddAlter), + } + + if index.IsPartitionOfIndex() { + // Since dropping the parent index will cause the partition of the index to drop, the parent drop should come + // before + dependencies = append(dependencies, + mustRun(isg.GetSQLVertexId(index), diffTypeDelete).after(buildIndexVertexId(index.ParentIdxName), diffTypeDelete)) + } + dependencies = append(dependencies, isg.addDepsOnTableAddAlterIfNecessary(index)...) + + return dependencies +} + +func (isg *indexSQLVertexGenerator) addDepsOnTableAddAlterIfNecessary(index schema.Index) []dependency { + // This could be cleaner if start sorting columns separately in the graph + + parentTable, ok := isg.tablesInNewSchemaByName[index.TableName] + if !ok { + // If the parent table is deleted, we don't need to worry about making the index statement come + // before any alters + return nil + } + + // These dependencies will force the index deletion statement to come before the table AddAlter + addAlterColumnDeps := []dependency{ + mustRun(isg.GetSQLVertexId(index), diffTypeDelete).before(buildTableVertexId(index.TableName), diffTypeAddAlter), + } + if len(parentTable.ParentTableName) > 0 { + // If the table is partitioned, columns modifications occur on the base table not the children. Thus, we + // need the dependency to also be on the parent table add/alter statements + addAlterColumnDeps = append( + addAlterColumnDeps, + mustRun(isg.GetSQLVertexId(index), diffTypeDelete).before(buildTableVertexId(parentTable.ParentTableName), diffTypeAddAlter), + ) + } + + // If the parent table still exists and the index is a primary key, we should drop the PK index before + // any statements associated with altering the table run. This is important for changing the nullability of + // columns + if index.IsPk { + return addAlterColumnDeps + } + + parentTableColumnsByName := buildSchemaObjMap(parentTable.Columns) + for _, idxColumn := range index.Columns { + // We need to force the index drop to come before the statements to drop columns. Otherwise, the columns + // drops will force the index to drop non-concurrently + if _, columnStillPresent := parentTableColumnsByName[idxColumn]; !columnStillPresent { + return addAlterColumnDeps + } + } + + return nil +} + +type checkConstraintSQLGenerator struct { + tableName string +} + +func (csg *checkConstraintSQLGenerator) Add(con schema.CheckConstraint) ([]Statement, error) { + // UDF's in check constraints are a bad idea. Check constraints are not re-validated + // if the UDF changes, so it's not really a safe practice. We won't support it for now + if len(con.DependsOnFunctions) > 0 { + return nil, fmt.Errorf("check constraints that depend on UDFs: %w", ErrNotImplemented) + } + + sb := strings.Builder{} + sb.WriteString(fmt.Sprintf("%s ADD CONSTRAINT %s CHECK(%s)", alterTablePrefix(csg.tableName), schema.EscapeIdentifier(con.Name), con.Expression)) + if !con.IsInheritable { + sb.WriteString(" NO INHERIT") + } + + if !con.IsValid { + sb.WriteString(" NOT VALID") + } + + return []Statement{{ + DDL: sb.String(), + Timeout: statementTimeoutDefault, + }}, nil +} + +func (csg *checkConstraintSQLGenerator) Delete(con schema.CheckConstraint) ([]Statement, error) { + // We won't support deleting check constraints depending on UDF's to align with not supporting adding check + // constraints that depend on UDF's + if len(con.DependsOnFunctions) > 0 { + return nil, fmt.Errorf("check constraints that depend on UDFs: %w", ErrNotImplemented) + } + + return []Statement{{ + DDL: dropConstraintDDL(csg.tableName, con.Name), + Timeout: statementTimeoutDefault, + }}, nil +} + +func (csg *checkConstraintSQLGenerator) Alter(diff checkConstraintDiff) ([]Statement, error) { + if cmp.Equal(diff.old, diff.new) { + return nil, nil + } + + oldCopy := diff.old + oldCopy.IsValid = diff.new.IsValid + if !cmp.Equal(oldCopy, diff.new) { + // Technically, we could support altering expression, but I don't see the use case for it. it would require more test + // cases than forceReadding it, and I'm not convinced it unlocks any functionality + return nil, fmt.Errorf("altering check constraint to resolve the following diff %s: %w", cmp.Diff(oldCopy, diff.new), ErrNotImplemented) + } else if diff.old.IsValid && !diff.new.IsValid { + return nil, fmt.Errorf("check constraint can't go from invalid to valid") + } else if len(diff.old.DependsOnFunctions) > 0 || len(diff.new.DependsOnFunctions) > 0 { + return nil, fmt.Errorf("check constraints that depend on UDFs: %w", ErrNotImplemented) + } + + return []Statement{{ + DDL: fmt.Sprintf("%s VALIDATE CONSTRAINT %s", alterTablePrefix(csg.tableName), schema.EscapeIdentifier(diff.old.Name)), + Timeout: statementTimeoutDefault, + }}, nil +} + +type attachPartitionSQLVertexGenerator struct { + indexesInNewSchemaByTableName map[string][]schema.Index +} + +func (*attachPartitionSQLVertexGenerator) Add(table schema.Table) ([]Statement, error) { + if !table.IsPartition() { + return nil, nil + } + return []Statement{buildAttachPartitionStatement(table)}, nil +} + +func (*attachPartitionSQLVertexGenerator) Alter(_ tableDiff) ([]Statement, error) { + return nil, nil +} + +func buildAttachPartitionStatement(table schema.Table) Statement { + return Statement{ + DDL: fmt.Sprintf("%s ATTACH PARTITION %s %s", alterTablePrefix(table.ParentTableName), schema.EscapeIdentifier(table.Name), table.ForValues), + Timeout: statementTimeoutDefault, + } +} + +func (*attachPartitionSQLVertexGenerator) Delete(_ schema.Table) ([]Statement, error) { + return nil, nil +} + +func (*attachPartitionSQLVertexGenerator) GetSQLVertexId(table schema.Table) string { + return fmt.Sprintf("attachpartition_%s", table.Name) +} + +func (a *attachPartitionSQLVertexGenerator) GetAddAlterDependencies(table, _ schema.Table) []dependency { + deps := []dependency{ + mustRun(a.GetSQLVertexId(table), diffTypeAddAlter).after(buildTableVertexId(table.Name), diffTypeAddAlter), + } + + for _, idx := range a.indexesInNewSchemaByTableName[table.Name] { + deps = append(deps, mustRun(a.GetSQLVertexId(table), diffTypeAddAlter).after(buildIndexVertexId(idx.Name), diffTypeAddAlter)) + } + return deps +} + +func (a *attachPartitionSQLVertexGenerator) GetDeleteDependencies(_ schema.Table) []dependency { + return nil +} + +type functionSQLVertexGenerator struct { + // functionsInNewSchemaByName is a map of function new to functions in the new schema. + // These functions are not necessarily new + functionsInNewSchemaByName map[string]schema.Function +} + +func (f *functionSQLVertexGenerator) Add(function schema.Function) ([]Statement, error) { + var hazards []MigrationHazard + if !canFunctionDependenciesBeTracked(function) { + hazards = append(hazards, migrationHazardAddAlterFunctionCannotTrackDependencies) + } + return []Statement{{ + DDL: function.FunctionDef, + Timeout: statementTimeoutDefault, + Hazards: hazards, + }}, nil +} + +func (f *functionSQLVertexGenerator) Delete(function schema.Function) ([]Statement, error) { + var hazards []MigrationHazard + if !canFunctionDependenciesBeTracked(function) { + hazards = append(hazards, MigrationHazard{ + Type: MigrationHazardTypeHasUntrackableDependencies, + Message: "Dependencies, i.e. other functions used in the function body, of non-sql functions cannot be " + + "tracked. As a result, we cannot guarantee that function dependencies are ordered properly relative to " + + "this statement. For drops, this means you need to ensure that all functions this function depends on " + + "are dropped after this statement.", + }) + } + return []Statement{{ + DDL: fmt.Sprintf("DROP FUNCTION %s", function.GetFQEscapedName()), + Timeout: statementTimeoutDefault, + Hazards: hazards, + }}, nil +} + +func (f *functionSQLVertexGenerator) Alter(diff functionDiff) ([]Statement, error) { + if cmp.Equal(diff.old, diff.new) { + return nil, nil + } + + var hazards []MigrationHazard + if !canFunctionDependenciesBeTracked(diff.new) { + hazards = append(hazards, migrationHazardAddAlterFunctionCannotTrackDependencies) + } + return []Statement{{ + DDL: diff.new.FunctionDef, + Timeout: statementTimeoutDefault, + Hazards: hazards, + }}, nil +} + +func canFunctionDependenciesBeTracked(function schema.Function) bool { + return function.Language == "sql" +} + +func (f *functionSQLVertexGenerator) GetSQLVertexId(function schema.Function) string { + return buildFunctionVertexId(function.SchemaQualifiedName) +} + +func (f *functionSQLVertexGenerator) GetAddAlterDependencies(newFunction, oldFunction schema.Function) []dependency { + // Since functions can just be `CREATE OR REPLACE`, there will never be a case where a function is + // added and dropped in the same migration. Thus, we don't need a dependency on the delete vertex of a function + // because there won't be one if it is being added/altered + var deps []dependency + for _, depFunction := range newFunction.DependsOnFunctions { + deps = append(deps, mustRun(f.GetSQLVertexId(newFunction), diffTypeAddAlter).after(buildFunctionVertexId(depFunction), diffTypeAddAlter)) + } + + if !cmp.Equal(oldFunction, schema.Function{}) { + // If the function is being altered: + // If the old version of the function calls other functions that are being deleted come, those deletions + // must come after the function is altered, so it is no longer dependent on those dropped functions + for _, depFunction := range oldFunction.DependsOnFunctions { + deps = append(deps, mustRun(f.GetSQLVertexId(newFunction), diffTypeAddAlter).before(buildFunctionVertexId(depFunction), diffTypeDelete)) + } + } + + return deps +} + +func (f *functionSQLVertexGenerator) GetDeleteDependencies(function schema.Function) []dependency { + var deps []dependency + for _, depFunction := range function.DependsOnFunctions { + deps = append(deps, mustRun(f.GetSQLVertexId(function), diffTypeDelete).before(buildFunctionVertexId(depFunction), diffTypeDelete)) + } + return deps +} + +func buildFunctionVertexId(name schema.SchemaQualifiedName) string { + return buildVertexId("function", name.GetFQEscapedName()) +} + +type triggerSQLVertexGenerator struct { + // functionsInNewSchemaByName is a map of function new to functions in the new schema. + // These functions are not necessarily new + functionsInNewSchemaByName map[string]schema.Function +} + +func (t *triggerSQLVertexGenerator) Add(trigger schema.Trigger) ([]Statement, error) { + return []Statement{{ + DDL: string(trigger.GetTriggerDefStmt), + Timeout: statementTimeoutDefault, + }}, nil +} + +func (t *triggerSQLVertexGenerator) Delete(trigger schema.Trigger) ([]Statement, error) { + return []Statement{{ + DDL: fmt.Sprintf("DROP TRIGGER %s ON %s", trigger.EscapedName, trigger.OwningTable.GetFQEscapedName()), + Timeout: statementTimeoutDefault, + }}, nil +} + +func (t *triggerSQLVertexGenerator) Alter(diff triggerDiff) ([]Statement, error) { + if cmp.Equal(diff.old, diff.new) { + return nil, nil + } + + createOrReplaceStmt, err := diff.new.GetTriggerDefStmt.ToCreateOrReplace() + if err != nil { + return nil, fmt.Errorf("modifying get trigger def statement to create or replace: %w", err) + } + return []Statement{{ + DDL: createOrReplaceStmt, + Timeout: statementTimeoutDefault, + }}, nil +} + +func (t *triggerSQLVertexGenerator) GetSQLVertexId(trigger schema.Trigger) string { + return buildVertexId("trigger", trigger.GetName()) +} + +func (t *triggerSQLVertexGenerator) GetAddAlterDependencies(newTrigger, oldTrigger schema.Trigger) []dependency { + // Since a trigger can just be `CREATE OR REPLACE`, there will never be a case where a trigger is + // added and dropped in the same migration. Thus, we don't need a dependency on the delete node of a function + // because there won't be one if it is being added/altered + deps := []dependency{ + mustRun(t.GetSQLVertexId(newTrigger), diffTypeAddAlter).after(buildFunctionVertexId(newTrigger.Function), diffTypeAddAlter), + mustRun(t.GetSQLVertexId(newTrigger), diffTypeAddAlter).after(buildTableVertexId(newTrigger.OwningTableUnescapedName), diffTypeAddAlter), + } + + if !cmp.Equal(oldTrigger, schema.Trigger{}) { + // If the trigger is being altered: + // If the old version of the trigger called a function being deleted, the function deletion must come after the + // trigger is altered, so the trigger no longer has a dependency on the function + deps = append(deps, + mustRun(t.GetSQLVertexId(newTrigger), diffTypeAddAlter).before(buildFunctionVertexId(oldTrigger.Function), diffTypeDelete), + ) + } + + return deps +} + +func (t *triggerSQLVertexGenerator) GetDeleteDependencies(trigger schema.Trigger) []dependency { + return []dependency{ + mustRun(t.GetSQLVertexId(trigger), diffTypeDelete).before(buildFunctionVertexId(trigger.Function), diffTypeDelete), + mustRun(t.GetSQLVertexId(trigger), diffTypeDelete).before(buildTableVertexId(trigger.OwningTableUnescapedName), diffTypeDelete), + } +} + +func buildVertexId(objType string, id string) string { + return fmt.Sprintf("%s_%s", objType, id) +} + +func stripMigrationHazards(stmts []Statement) []Statement { + var noHazardsStmts []Statement + for _, stmt := range stmts { + stmt.Hazards = nil + noHazardsStmts = append(noHazardsStmts, stmt) + } + return noHazardsStmts +} + +func dropConstraintDDL(tableName, constraintName string) string { + return fmt.Sprintf("%s DROP CONSTRAINT %s", alterTablePrefix(tableName), schema.EscapeIdentifier(constraintName)) +} + +func alterTablePrefix(tableName string) string { + return fmt.Sprintf("ALTER TABLE %s", schema.EscapeIdentifier(tableName)) +} + +func buildColumnDefinition(column schema.Column) string { + sb := strings.Builder{} + sb.WriteString(fmt.Sprintf("%s %s", schema.EscapeIdentifier(column.Name), column.Type)) + if column.IsCollated() { + sb.WriteString(fmt.Sprintf(" COLLATE %s", column.Collation.GetFQEscapedName())) + } + if !column.IsNullable { + sb.WriteString(" NOT NULL") + } + if len(column.Default) > 0 { + sb.WriteString(fmt.Sprintf(" DEFAULT %s", column.Default)) + } + return sb.String() +} + +func formattedNamesForSQL(names []string) []string { + var formattedNames []string + for _, name := range names { + formattedNames = append(formattedNames, schema.EscapeIdentifier(name)) + } + return formattedNames +} diff --git a/pkg/diff/sql_graph.go b/pkg/diff/sql_graph.go new file mode 100644 index 0000000..9088562 --- /dev/null +++ b/pkg/diff/sql_graph.go @@ -0,0 +1,66 @@ +package diff + +import ( + "fmt" + + "github.com/stripe/pg-schema-diff/internal/graph" +) + +type sqlVertex struct { + ObjId string + Statements []Statement + DiffType diffType +} + +func (s sqlVertex) GetId() string { + return fmt.Sprintf("%s_%s", s.DiffType, s.ObjId) +} + +func buildTableVertexId(name string) string { + return fmt.Sprintf("table_%s", name) +} + +func buildIndexVertexId(name string) string { + return fmt.Sprintf("index_%s", name) +} + +// sqlGraph represents two dependency webs of SQL statements +type sqlGraph graph.Graph[sqlVertex] + +// union unions the two AddsAndAlters graphs and, separately, unions the two delete graphs +func (s *sqlGraph) union(sqlGraph *sqlGraph) error { + if err := (*graph.Graph[sqlVertex])(s).Union((*graph.Graph[sqlVertex])(sqlGraph), mergeSQLVertices); err != nil { + return fmt.Errorf("unioning the graphs: %w", err) + } + return nil +} + +func mergeSQLVertices(old, new sqlVertex) sqlVertex { + return sqlVertex{ + ObjId: old.ObjId, + DiffType: old.DiffType, + Statements: append(old.Statements, new.Statements...), + } +} + +func (s *sqlGraph) toOrderedStatements() ([]Statement, error) { + vertices, err := (*graph.Graph[sqlVertex])(s).TopologicallySortWithPriority(graph.IsLowerPriorityFromGetPriority( + func(vertex sqlVertex) int { + multiplier := 1 + if vertex.DiffType == diffTypeDelete { + multiplier = -1 + } + // Prioritize adds/alters over deletes. Weight by number of statements. A 0 statement delete should be + // prioritized over a 1 statement delete + return len(vertex.Statements) * multiplier + }), + ) + if err != nil { + return nil, fmt.Errorf("topologically sorting graph: %w", err) + } + var stmts []Statement + for _, v := range vertices { + stmts = append(stmts, v.Statements...) + } + return stmts, nil +} diff --git a/pkg/diff/transform_diff.go b/pkg/diff/transform_diff.go new file mode 100644 index 0000000..98ff5ca --- /dev/null +++ b/pkg/diff/transform_diff.go @@ -0,0 +1,44 @@ +package diff + +import ( + "sort" + + "github.com/stripe/pg-schema-diff/internal/schema" +) + +// dataPackNewTables packs the columns in new tables to minimize the space they occupy +// +// Note: We need to copy all arrays we modify because otherwise those arrays (effectively pointers) +// will still exist in the original structs, leading to mutation. Go is copy-by-value, but all slices +// are pointers. If we don't copy the arrays we change, then changing a struct in the array will mutate the +// original SchemaDiff struct +func dataPackNewTables(s schemaDiff) schemaDiff { + copiedNewTables := append([]schema.Table(nil), s.tableDiffs.adds...) + for i, _ := range copiedNewTables { + copiedColumns := append([]schema.Column(nil), copiedNewTables[i].Columns...) + copiedNewTables[i].Columns = copiedColumns + sort.Slice(copiedColumns, func(i, j int) bool { + // Sort in descending order of size + return copiedColumns[i].Size > copiedColumns[j].Size + }) + } + s.tableDiffs.adds = copiedNewTables + + return s +} + +// removeChangesToColumnOrdering removes any changes to column ordering. In effect, it tells the SQL +// generator to ignore changes to column ordering +func removeChangesToColumnOrdering(s schemaDiff) schemaDiff { + copiedTableDiffs := append([]tableDiff(nil), s.tableDiffs.alters...) + for i, _ := range copiedTableDiffs { + copiedColDiffs := append([]columnDiff(nil), copiedTableDiffs[i].columnsDiff.alters...) + for i, _ := range copiedColDiffs { + copiedColDiffs[i].oldOrdering = copiedColDiffs[i].newOrdering + } + copiedTableDiffs[i].columnsDiff.alters = copiedColDiffs + } + s.tableDiffs.alters = copiedTableDiffs + + return s +} diff --git a/pkg/diff/transform_diff_test.go b/pkg/diff/transform_diff_test.go new file mode 100644 index 0000000..8c509b8 --- /dev/null +++ b/pkg/diff/transform_diff_test.go @@ -0,0 +1,354 @@ +package diff + +import ( + "sort" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stripe/pg-schema-diff/internal/schema" +) + +func TestTransformDiffDataPackNewTables(t *testing.T) { + tcs := []transformDiffTestCase{ + { + name: "No new tables", + in: schemaDiff{ + oldAndNew: oldAndNew[schema.Schema]{}, + tableDiffs: listDiff[schema.Table, tableDiff]{}, + indexDiffs: listDiff[schema.Index, indexDiff]{}, + functionDiffs: listDiff[schema.Function, functionDiff]{}, + triggerDiffs: listDiff[schema.Trigger, triggerDiff]{}, + }, + expectedOut: schemaDiff{ + oldAndNew: oldAndNew[schema.Schema]{}, + tableDiffs: listDiff[schema.Table, tableDiff]{}, + indexDiffs: listDiff[schema.Index, indexDiff]{}, + functionDiffs: listDiff[schema.Function, functionDiff]{}, + triggerDiffs: listDiff[schema.Trigger, triggerDiff]{}, + }, + }, + { + name: "New table with no columns", + in: schemaDiff{ + oldAndNew: oldAndNew[schema.Schema]{}, + tableDiffs: listDiff[schema.Table, tableDiff]{ + adds: []schema.Table{ + { + Name: "foobar", + Columns: nil, + CheckConstraints: nil, + }, + }, + }, + indexDiffs: listDiff[schema.Index, indexDiff]{}, + functionDiffs: listDiff[schema.Function, functionDiff]{}, + triggerDiffs: listDiff[schema.Trigger, triggerDiff]{}, + }, + expectedOut: schemaDiff{ + oldAndNew: oldAndNew[schema.Schema]{}, + tableDiffs: listDiff[schema.Table, tableDiff]{ + adds: []schema.Table{ + { + Name: "foobar", + Columns: nil, + CheckConstraints: nil, + }, + }, + }, + indexDiffs: listDiff[schema.Index, indexDiff]{}, + functionDiffs: listDiff[schema.Function, functionDiff]{}, + triggerDiffs: listDiff[schema.Trigger, triggerDiff]{}, + }, + }, + { + name: "Standard", + in: schemaDiff{ + oldAndNew: oldAndNew[schema.Schema]{}, + tableDiffs: listDiff[schema.Table, tableDiff]{ + alters: []tableDiff{ + buildTableDiffWithColDiffs("foo", []columnDiff{ + buildColumnDiff(schema.Column{Name: "genre", Type: "some type", Size: 3}, 0, 0), + buildColumnDiff(schema.Column{Name: "content", Type: "some type", Size: 2}, 1, 1), + buildColumnDiff(schema.Column{Name: "title", Type: "some type", Size: 10}, 2, 2), + }), + }, + adds: []schema.Table{ + { + Name: "foobar", + Columns: []schema.Column{ + {Name: "coffee", Type: "some type", Size: 3}, + {Name: "mocha", Type: "some type", Size: 2}, + {Name: "latte", Type: "some type", Size: 10}, + }, + CheckConstraints: nil, + }, + { + Name: "baz", + Columns: []schema.Column{ + {Name: "dog", Type: "some type", Size: 1}, + {Name: "cat", Type: "some type", Size: 2}, + {Name: "rabbit", Type: "some type", Size: 3}, + }, + CheckConstraints: nil, + }, + }, + deletes: []schema.Table{ + { + Name: "fizz", + Columns: []schema.Column{ + {Name: "croissant", Type: "some type", Size: 3}, + {Name: "bagel", Type: "some type", Size: 2}, + {Name: "pastry", Type: "some type", Size: 10}, + }, + CheckConstraints: nil, + }, + }, + }, + indexDiffs: listDiff[schema.Index, indexDiff]{}, + functionDiffs: listDiff[schema.Function, functionDiff]{}, + triggerDiffs: listDiff[schema.Trigger, triggerDiff]{}, + }, + expectedOut: schemaDiff{ + oldAndNew: oldAndNew[schema.Schema]{}, + tableDiffs: listDiff[schema.Table, tableDiff]{ + alters: []tableDiff{ + buildTableDiffWithColDiffs("foo", []columnDiff{ + buildColumnDiff(schema.Column{Name: "genre", Type: "some type", Size: 3}, 0, 0), + buildColumnDiff(schema.Column{Name: "content", Type: "some type", Size: 2}, 1, 1), + buildColumnDiff(schema.Column{Name: "title", Type: "some type", Size: 10}, 2, 2), + }), + }, + adds: []schema.Table{ + { + Name: "foobar", + Columns: []schema.Column{ + {Name: "latte", Type: "some type", Size: 10}, + {Name: "coffee", Type: "some type", Size: 3}, + {Name: "mocha", Type: "some type", Size: 2}, + }, + CheckConstraints: nil, + }, + { + Name: "baz", + Columns: []schema.Column{ + {Name: "rabbit", Type: "some type", Size: 3}, + {Name: "cat", Type: "some type", Size: 2}, + {Name: "dog", Type: "some type", Size: 1}, + }, + CheckConstraints: nil, + }, + }, + deletes: []schema.Table{ + { + Name: "fizz", + Columns: []schema.Column{ + {Name: "croissant", Type: "some type", Size: 3}, + {Name: "bagel", Type: "some type", Size: 2}, + {Name: "pastry", Type: "some type", Size: 10}, + }, + CheckConstraints: nil, + }, + }, + }, + indexDiffs: listDiff[schema.Index, indexDiff]{}, + functionDiffs: listDiff[schema.Function, functionDiff]{}, + triggerDiffs: listDiff[schema.Trigger, triggerDiff]{}, + }, + }, + } + runTestCases(t, dataPackNewTables, tcs) +} + +func TestTransformDiffRemoveChangesToColumnOrdering(t *testing.T) { + tcs := []transformDiffTestCase{ + { + name: "No altered tables", + in: schemaDiff{ + oldAndNew: oldAndNew[schema.Schema]{}, + tableDiffs: listDiff[schema.Table, tableDiff]{}, + indexDiffs: listDiff[schema.Index, indexDiff]{}, + functionDiffs: listDiff[schema.Function, functionDiff]{}, + triggerDiffs: listDiff[schema.Trigger, triggerDiff]{}, + }, + expectedOut: schemaDiff{ + oldAndNew: oldAndNew[schema.Schema]{}, + tableDiffs: listDiff[schema.Table, tableDiff]{}, + indexDiffs: listDiff[schema.Index, indexDiff]{}, + functionDiffs: listDiff[schema.Function, functionDiff]{}, + triggerDiffs: listDiff[schema.Trigger, triggerDiff]{}, + }, + }, + { + name: "Altered table with no columns", + in: schemaDiff{ + oldAndNew: oldAndNew[schema.Schema]{}, + tableDiffs: listDiff[schema.Table, tableDiff]{ + alters: []tableDiff{buildTableDiffWithColDiffs("foobar", nil)}, + }, + indexDiffs: listDiff[schema.Index, indexDiff]{}, + functionDiffs: listDiff[schema.Function, functionDiff]{}, + triggerDiffs: listDiff[schema.Trigger, triggerDiff]{}, + }, + expectedOut: schemaDiff{ + oldAndNew: oldAndNew[schema.Schema]{}, + tableDiffs: listDiff[schema.Table, tableDiff]{ + alters: []tableDiff{buildTableDiffWithColDiffs("foobar", nil)}, + }, + indexDiffs: listDiff[schema.Index, indexDiff]{}, + functionDiffs: listDiff[schema.Function, functionDiff]{}, + triggerDiffs: listDiff[schema.Trigger, triggerDiff]{}, + }, + }, + { + name: "Standard", + in: schemaDiff{ + oldAndNew: oldAndNew[schema.Schema]{}, + tableDiffs: listDiff[schema.Table, tableDiff]{ + alters: []tableDiff{ + buildTableDiffWithColDiffs("foo", []columnDiff{ + buildColumnDiff(schema.Column{Name: "genre", Type: "some type", Size: 3}, 0, 1), + buildColumnDiff(schema.Column{Name: "content", Type: "some type", Size: 2}, 1, 2), + buildColumnDiff(schema.Column{Name: "title", Type: "some type", Size: 10}, 2, 0), + }), + buildTableDiffWithColDiffs("bar", []columnDiff{ + buildColumnDiff(schema.Column{Name: "item", Type: "some type", Size: 3}, 0, 2), + buildColumnDiff(schema.Column{Name: "type", Type: "some type", Size: 2}, 1, 1), + buildColumnDiff(schema.Column{Name: "color", Type: "some type", Size: 10}, 2, 0), + }), + }, + adds: []schema.Table{ + { + Name: "foobar", + Columns: []schema.Column{ + {Name: "cold brew", Type: "some type", Size: 2}, + {Name: "coffee", Type: "some type", Size: 3}, + {Name: "mocha", Type: "some type", Size: 2}, + {Name: "latte", Type: "some type", Size: 10}, + }, + CheckConstraints: nil, + }, + }, + deletes: []schema.Table{ + { + Name: "fizz", + Columns: []schema.Column{ + {Name: "croissant", Type: "some type", Size: 3}, + {Name: "bagel", Type: "some type", Size: 2}, + {Name: "pastry", Type: "some type", Size: 10}, + }, + CheckConstraints: nil, + }, + }, + }, + indexDiffs: listDiff[schema.Index, indexDiff]{}, + functionDiffs: listDiff[schema.Function, functionDiff]{}, + triggerDiffs: listDiff[schema.Trigger, triggerDiff]{}, + }, + expectedOut: schemaDiff{ + oldAndNew: oldAndNew[schema.Schema]{}, + tableDiffs: listDiff[schema.Table, tableDiff]{ + alters: []tableDiff{ + buildTableDiffWithColDiffs("foo", []columnDiff{ + buildColumnDiff(schema.Column{Name: "genre", Type: "some type", Size: 3}, 1, 1), + buildColumnDiff(schema.Column{Name: "content", Type: "some type", Size: 2}, 2, 2), + buildColumnDiff(schema.Column{Name: "title", Type: "some type", Size: 10}, 0, 0), + }), + buildTableDiffWithColDiffs("bar", []columnDiff{ + buildColumnDiff(schema.Column{Name: "item", Type: "some type", Size: 3}, 2, 2), + buildColumnDiff(schema.Column{Name: "type", Type: "some type", Size: 2}, 1, 1), + buildColumnDiff(schema.Column{Name: "color", Type: "some type", Size: 10}, 0, 0), + }), + }, + adds: []schema.Table{ + { + Name: "foobar", + Columns: []schema.Column{ + {Name: "cold brew", Type: "some type", Size: 2}, + {Name: "coffee", Type: "some type", Size: 3}, + {Name: "mocha", Type: "some type", Size: 2}, + {Name: "latte", Type: "some type", Size: 10}, + }, + CheckConstraints: nil, + }, + }, + deletes: []schema.Table{ + { + Name: "fizz", + Columns: []schema.Column{ + {Name: "croissant", Type: "some type", Size: 3}, + {Name: "bagel", Type: "some type", Size: 2}, + {Name: "pastry", Type: "some type", Size: 10}, + }, + CheckConstraints: nil, + }, + }, + }, + indexDiffs: listDiff[schema.Index, indexDiff]{}, + functionDiffs: listDiff[schema.Function, functionDiff]{}, + triggerDiffs: listDiff[schema.Trigger, triggerDiff]{}, + }, + }, + } + runTestCases(t, removeChangesToColumnOrdering, tcs) +} + +type transformDiffTestCase struct { + name string + in schemaDiff + expectedOut schemaDiff +} + +func runTestCases(t *testing.T, transform func(diff schemaDiff) schemaDiff, tcs []transformDiffTestCase) { + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expectedOut, transform(tc.in)) + }) + } +} + +func buildTableDiffWithColDiffs(name string, columnDiffs []columnDiff) tableDiff { + var oldColumns, newColumns []schema.Column + copiedColumnDiffs := append([]columnDiff(nil), columnDiffs...) + sort.Slice(copiedColumnDiffs, func(i, j int) bool { + return copiedColumnDiffs[i].oldOrdering < copiedColumnDiffs[j].oldOrdering + }) + for _, colDiff := range copiedColumnDiffs { + oldColumns = append(oldColumns, colDiff.old) + } + sort.Slice(copiedColumnDiffs, func(i, j int) bool { + return copiedColumnDiffs[i].newOrdering < copiedColumnDiffs[j].newOrdering + }) + for _, colDiff := range copiedColumnDiffs { + oldColumns = append(newColumns, colDiff.new) + } + + return tableDiff{ + oldAndNew: oldAndNew[schema.Table]{ + old: schema.Table{ + Name: name, + Columns: oldColumns, + CheckConstraints: nil, + }, + new: schema.Table{ + Name: name, + Columns: newColumns, + CheckConstraints: nil, + }, + }, + columnsDiff: listDiff[schema.Column, columnDiff]{ + alters: columnDiffs, + }, + checkConstraintDiff: listDiff[schema.CheckConstraint, checkConstraintDiff]{}, + } +} + +func buildColumnDiff(col schema.Column, oldOrdering, newOrdering int) columnDiff { + return columnDiff{ + oldAndNew: oldAndNew[schema.Column]{ + old: col, + new: col, + }, + oldOrdering: oldOrdering, + newOrdering: newOrdering, + } +} diff --git a/pkg/log/logger.go b/pkg/log/logger.go new file mode 100644 index 0000000..3c1a2a5 --- /dev/null +++ b/pkg/log/logger.go @@ -0,0 +1,24 @@ +package log + +import ( + "fmt" + "log" +) + +type ( + Logger interface { + Errorf(msg string, args ...any) + } + + simpleLogger struct{} +) + +// SimpleLogger is a bare-bones implementation of the logging interface, e.g., used for testing +func SimpleLogger() Logger { + return &simpleLogger{} +} + +func (*simpleLogger) Errorf(msg string, args ...any) { + formattedMessage := fmt.Sprintf(msg, args...) + log.Println(fmt.Sprintf("[ERROR] %s", formattedMessage)) +} diff --git a/pkg/schema/schema.go b/pkg/schema/schema.go new file mode 100644 index 0000000..f004070 --- /dev/null +++ b/pkg/schema/schema.go @@ -0,0 +1,25 @@ +package schema + +import ( + "context" + "database/sql" + "fmt" + + internalschema "github.com/stripe/pg-schema-diff/internal/schema" +) + +// GetPublicSchemaHash hash gets the hash of the "public" schema. It can be used to compare against the hash in the migration +// plan to determine if it's still valid +// We do not expose the Schema struct yet because it is subject to change, and we do not want folks depending on its API +func GetPublicSchemaHash(ctx context.Context, conn *sql.Conn) (string, error) { + schema, err := internalschema.GetPublicSchema(ctx, conn) + if err != nil { + return "", fmt.Errorf("getting public schema: %w", err) + } + hash, err := schema.Hash() + if err != nil { + return "", fmt.Errorf("hashing schema: %w", err) + } + + return hash, nil +} diff --git a/pkg/schema/schema_test.go b/pkg/schema/schema_test.go new file mode 100644 index 0000000..577c975 --- /dev/null +++ b/pkg/schema/schema_test.go @@ -0,0 +1,101 @@ +package schema_test + +import ( + "context" + "database/sql" + "testing" + + "github.com/stretchr/testify/suite" + "github.com/stripe/pg-schema-diff/internal/pgengine" + "github.com/stripe/pg-schema-diff/pkg/schema" +) + +type schemaTestSuite struct { + suite.Suite + pgEngine *pgengine.Engine +} + +func (suite *schemaTestSuite) SetupSuite() { + engine, err := pgengine.StartEngine() + suite.Require().NoError(err) + suite.pgEngine = engine +} + +func (suite *schemaTestSuite) TearDownSuite() { + suite.pgEngine.Close() +} + +func (suite *schemaTestSuite) TestGetPublicSchemaHash() { + const ( + ddl = ` + CREATE FUNCTION add(a integer, b integer) RETURNS integer + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN a + b; + + CREATE FUNCTION increment(i integer) RETURNS integer AS $$ + BEGIN + RETURN i + 1; + END; + $$ LANGUAGE plpgsql; + + CREATE FUNCTION function_with_dependencies(a integer, b integer) RETURNS integer + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN add(a, b) + increment(a); + + CREATE TABLE foo ( + id INTEGER PRIMARY KEY, + author TEXT COLLATE "C", + content TEXT NOT NULL DEFAULT '', + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP CHECK (created_at > CURRENT_TIMESTAMP - interval '1 month') NO INHERIT, + version INT NOT NULL DEFAULT 0, + CHECK ( function_with_dependencies(id, id) > 0) + ); + + ALTER TABLE foo ADD CONSTRAINT author_check CHECK (author IS NOT NULL AND LENGTH(author) > 0) NO INHERIT NOT VALID; + CREATE INDEX some_idx ON foo USING hash (content); + CREATE UNIQUE INDEX some_unique_idx ON foo (created_at DESC, author ASC); + + CREATE FUNCTION increment_version() RETURNS TRIGGER AS $$ + BEGIN + NEW.version = OLD.version + 1; + RETURN NEW; + END; + $$ language 'plpgsql'; + + CREATE TRIGGER some_trigger + BEFORE UPDATE ON foo + FOR EACH ROW + WHEN (OLD.* IS DISTINCT FROM NEW.*) + EXECUTE PROCEDURE increment_version(); + ` + + expectedHash = "9648c294aed76ef6" + ) + db, err := suite.pgEngine.CreateDatabase() + suite.Require().NoError(err) + defer db.DropDB() + + connPool, err := sql.Open("pgx", db.GetDSN()) + suite.Require().NoError(err) + defer connPool.Close() + + _, err = connPool.ExecContext(context.Background(), ddl) + suite.Require().NoError(err) + + conn, err := connPool.Conn(context.Background()) + suite.Require().NoError(err) + defer conn.Close() + + hash, err := schema.GetPublicSchemaHash(context.Background(), conn) + suite.Require().NoError(err) + + suite.Equal(expectedHash, hash) +} + +func TestSchemaTestSuite(t *testing.T) { + suite.Run(t, new(schemaTestSuite)) +} diff --git a/pkg/tempdb/factory.go b/pkg/tempdb/factory.go new file mode 100644 index 0000000..e502bb4 --- /dev/null +++ b/pkg/tempdb/factory.go @@ -0,0 +1,208 @@ +package tempdb + +import ( + "context" + "database/sql" + "fmt" + "io" + "strings" + + "github.com/google/uuid" + "github.com/jackc/pgx/v4" + _ "github.com/jackc/pgx/v4/stdlib" + "github.com/stripe/pg-schema-diff/internal/pgidentifier" + "github.com/stripe/pg-schema-diff/pkg/log" +) + +const ( + DefaultOnInstanceDbPrefix = "pgschemadifftmp_" + DefaultOnInstanceMetadataSchema = "pgschemadifftmp_metadata" + DefaultOnInstanceMetadataTable = "metadata" +) + +// Factory is used to create temp databases These databases do not have to be in-memory. They might be, for example, +// be created on the target Postgres server +type ( + Dropper func(ctx context.Context) error + + Factory interface { + // Create creates a temporary database. Be sure to always call the Dropper to ensure the database and + // connections are cleaned up + Create(ctx context.Context) (db *sql.DB, dropper Dropper, err error) + + io.Closer + } +) + +type ( + onInstanceFactoryOptions struct { + dbPrefix string + metadataSchema string + metadataTable string + logger log.Logger + } + + OnInstanceFactoryOpt func(*onInstanceFactoryOptions) +) + +// WithLogger sets the logger for the factory. If not set, a SimpleLogger will be used +func WithLogger(logger log.Logger) OnInstanceFactoryOpt { + return func(opts *onInstanceFactoryOptions) { + opts.logger = logger + } +} + +// WithDbPrefix sets the prefix for the temp database name +func WithDbPrefix(prefix string) OnInstanceFactoryOpt { + return func(opts *onInstanceFactoryOptions) { + opts.dbPrefix = prefix + } +} + +// WithMetadataSchema sets the prefix for the schema name containing the metadata +func WithMetadataSchema(schema string) OnInstanceFactoryOpt { + return func(opts *onInstanceFactoryOptions) { + opts.metadataSchema = schema + } +} + +// WithMetadataTable sets the metadata table name +func WithMetadataTable(table string) OnInstanceFactoryOpt { + return func(opts *onInstanceFactoryOptions) { + opts.metadataTable = table + } +} + +type ( + CreateConnForDbFn func(ctx context.Context, dbName string) (*sql.DB, error) + + // onInstanceFactory creates temporary databases on the provided Postgres server + onInstanceFactory struct { + rootDb *sql.DB + createConnForDb CreateConnForDbFn + options onInstanceFactoryOptions + } +) + +// NewOnInstanceFactory provides an implementation to easily create temporary databases on the Postgres instance +// connected to via CreateConnForDbFn. The Postgres instance is connected to via the "postgres" database, and then +// temporary databases are created using that connection. These temporary databases are also connected to via the +// CreateConnForDbFn. +// Make sure to always call Close() on the returned Factory to ensure the root connection is closed +// +// WARNING: +// It is possible this will lead to orphaned temporary databases. These orphaned databases should be pretty small if +// they're only being used by the pg-schema-diff library, but it's recommended to clean them up when possible. This can +// be done by deleting all old databases with the provided temp db prefix. The metadata table can be inspected to find +// when the temporary database was created, e.g., to create a TTL +func NewOnInstanceFactory(ctx context.Context, createConnForDb CreateConnForDbFn, opts ...OnInstanceFactoryOpt) (Factory, error) { + options := onInstanceFactoryOptions{ + dbPrefix: DefaultOnInstanceDbPrefix, + metadataSchema: DefaultOnInstanceMetadataSchema, + metadataTable: DefaultOnInstanceMetadataTable, + logger: log.SimpleLogger(), + } + for _, opt := range opts { + opt(&options) + } + if !pgidentifier.IsSimpleIdentifier(options.dbPrefix) { + return nil, fmt.Errorf("dbPrefix (%s) must be a simple Postgres identifier matching the following regex: %s", options.dbPrefix, pgidentifier.SimpleIdentifierRegex) + } + + rootDb, err := createConnForDb(ctx, "postgres") + if err != nil { + return &onInstanceFactory{}, err + } + if err := assertConnPoolIsOnExpectedDatabase(ctx, rootDb, "postgres"); err != nil { + rootDb.Close() + return &onInstanceFactory{}, fmt.Errorf("assertConnPoolIsOnExpectedDatabase: %w", err) + } + + return &onInstanceFactory{ + rootDb: rootDb, + createConnForDb: createConnForDb, + options: options, + }, nil +} + +func (o *onInstanceFactory) Close() error { + return o.rootDb.Close() +} + +func (o *onInstanceFactory) Create(ctx context.Context) (sql *sql.DB, dropper Dropper, retErr error) { + dbUUID, err := uuid.NewUUID() + if err != nil { + return nil, nil, err + } + tempDbName := o.options.dbPrefix + strings.ReplaceAll(dbUUID.String(), "-", "_") + if _, err = o.rootDb.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s;", tempDbName)); err != nil { + return nil, nil, err + } + defer func() { + // Only drop the temp database if an error occurred during creation + if retErr != nil { + if err := o.dropTempDatabase(ctx, tempDbName); err != nil { + o.options.logger.Errorf("Failed to drop temporary database %s because of error %s. This drop was automatically triggered by error %s", tempDbName, err.Error(), retErr.Error()) + } + } + }() + + tempDbConn, err := o.createConnForDb(ctx, tempDbName) + if err != nil { + return nil, nil, err + } + defer func() { + // Only close the connection pool if an error occurred during creation. + // We should close the connection pool on the off-chance that the drop database fails + if retErr != nil { + _ = tempDbConn.Close() + } + }() + if err := assertConnPoolIsOnExpectedDatabase(ctx, tempDbConn, tempDbName); err != nil { + return nil, nil, fmt.Errorf("assertConnPoolIsOnExpectedDatabase: %w", err) + } + + // There's no easy way to keep track of when a database was created, so + // create a row with the temp database's time of creation. This will be used + // by a cleanup process to drop any temporary databases that were not cleaned up + // successfully by the "dropper" + sanitizedSchemaName := pgx.Identifier{o.options.metadataSchema}.Sanitize() + sanitizedTableName := pgx.Identifier{o.options.metadataSchema, o.options.metadataTable}.Sanitize() + createMetadataStmts := fmt.Sprintf(` + CREATE SCHEMA %s + CREATE TABLE %s( + db_created_at TIMESTAMPTZ NOT NULL DEFAULT current_timestamp + ); + INSERT INTO %s DEFAULT VALUES; + `, sanitizedSchemaName, sanitizedTableName, sanitizedTableName) + if _, err := tempDbConn.ExecContext(ctx, createMetadataStmts); err != nil { + return nil, nil, err + } + + return tempDbConn, func(ctx context.Context) error { + _ = tempDbConn.Close() + return o.dropTempDatabase(ctx, tempDbName) + }, nil + +} + +// assertConnPoolIsOnExpectedDatabase provides validation that a user properly passed in a proper CreateConnForDbFn +func assertConnPoolIsOnExpectedDatabase(ctx context.Context, connPool *sql.DB, expectedDatabase string) error { + var dbName string + if err := connPool.QueryRowContext(ctx, "SELECT current_database();").Scan(&dbName); err != nil { + return err + } + if dbName != expectedDatabase { + return fmt.Errorf("connection pool is on database %s, expected %s", dbName, expectedDatabase) + } + + return nil +} + +func (o *onInstanceFactory) dropTempDatabase(ctx context.Context, dbName string) error { + if !strings.HasPrefix(dbName, o.options.dbPrefix) { + return fmt.Errorf("drop non-temporary database: %s", dbName) + } + _, err := o.rootDb.ExecContext(ctx, fmt.Sprintf("DROP DATABASE %s;", dbName)) + return err +} diff --git a/pkg/tempdb/factory_test.go b/pkg/tempdb/factory_test.go new file mode 100644 index 0000000..f02fddb --- /dev/null +++ b/pkg/tempdb/factory_test.go @@ -0,0 +1,190 @@ +package tempdb + +import ( + "context" + "database/sql" + "fmt" + "strings" + "testing" + "time" + + _ "github.com/jackc/pgx/v4/stdlib" + "github.com/stretchr/testify/suite" + "github.com/stripe/pg-schema-diff/internal/pgengine" + + "github.com/stripe/pg-schema-diff/pkg/log" +) + +type onInstanceTempDbFactorySuite struct { + suite.Suite + + engine *pgengine.Engine +} + +func (suite *onInstanceTempDbFactorySuite) SetupSuite() { + engine, err := pgengine.StartEngine() + suite.Require().NoError(err) + suite.engine = engine +} + +func (suite *onInstanceTempDbFactorySuite) TearDownSuite() { + suite.engine.Close() +} + +func (suite *onInstanceTempDbFactorySuite) mustBuildFactory(opt ...OnInstanceFactoryOpt) Factory { + factory, err := NewOnInstanceFactory(context.Background(), func(ctx context.Context, dbName string) (*sql.DB, error) { + return suite.getConnPoolForDb(dbName) + }, opt...) + suite.Require().NoError(err) + return factory +} + +func (suite *onInstanceTempDbFactorySuite) getConnPoolForDb(dbName string) (*sql.DB, error) { + return sql.Open("pgx", suite.engine.GetPostgresDatabaseConnOpts().With("dbname", dbName).ToDSN()) +} + +func (suite *onInstanceTempDbFactorySuite) mustRunSQL(conn *sql.Conn) { + _, err := conn.ExecContext(context.Background(), ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + message TEXT + ); + CREATE INDEX message_idx ON foobar(message); + `) + suite.Require().NoError(err) + + _, err = conn.ExecContext(context.Background(), ` + INSERT INTO foobar VALUES (1, 'some message'), (2, 'some other message'), (3, 'a final message'); + `) + suite.Require().NoError(err) + + res, err := conn.QueryContext(context.Background(), ` + SELECT id, message FROM foobar; + `) + suite.Require().NoError(err) + + var rows [][]any + for res.Next() { + var id int + var message string + suite.Require().NoError(res.Scan(&id, &message)) + rows = append(rows, []any{ + id, message, + }) + } + suite.ElementsMatch([][]any{ + {1, "some message"}, + {2, "some other message"}, + {3, "a final message"}, + }, rows) +} + +func (suite *onInstanceTempDbFactorySuite) TestNew_ConnectsToWrongDatabase() { + db, err := suite.engine.CreateDatabaseWithName("not-postgres") + suite.Require().NoError(err) + defer db.DropDB() + + _, err = NewOnInstanceFactory(context.Background(), func(ctx context.Context, dbName string) (*sql.DB, error) { + return suite.getConnPoolForDb("not-postgres") + }) + suite.ErrorContains(err, "connection pool is on") +} + +func (suite *onInstanceTempDbFactorySuite) TestNew_ErrorsOnNonSimpleDbPrefix() { + _, err := NewOnInstanceFactory(context.Background(), func(ctx context.Context, dbName string) (*sql.DB, error) { + suite.Fail("shouldn't be reached") + return nil, nil + }, WithDbPrefix("non-simple identifier")) + suite.ErrorContains(err, "must be a simple Postgres identifier") +} + +func (suite *onInstanceTempDbFactorySuite) TestCreate_CreateAndDropFlow() { + const ( + dbPrefix = "some_prefix" + metadataSchema = "some metadata schema" + metadataTable = "some metadata table" + ) + factory := suite.mustBuildFactory( + WithDbPrefix(dbPrefix), + WithMetadataSchema(metadataSchema), + WithMetadataTable(metadataTable), + WithLogger(log.SimpleLogger()), + ) + defer func(factory Factory) { + suite.Require().NoError(factory.Close()) + }(factory) + + db, dropper, err := factory.Create(context.Background()) + suite.Require().NoError(err) + // don't defer dropping. we want to run assertions after it drops. if dropping fails, + // it shouldn't be a problem because names shouldn't conflict + afterTimeOfCreation := time.Now() + + conn1, err := db.Conn(context.Background()) + suite.Require().NoError(err) + + var dbName string + suite.Require().NoError(conn1.QueryRowContext(context.Background(), "SELECT current_database()").Scan(&dbName)) + suite.True(strings.HasPrefix(dbName, dbPrefix)) + suite.Len(dbName, len(dbPrefix)+36) // should be length of prefix + length of uuid + + // Make sure SQL can run on the connection + suite.mustRunSQL(conn1) + + // check the metadata entry exists + var createdAt time.Time + metadataQuery := fmt.Sprintf(` + SELECT * FROM "%s"."%s" + `, metadataSchema, metadataTable) + suite.Require().NoError(conn1.QueryRowContext(context.Background(), metadataQuery).Scan(&createdAt)) + suite.True(createdAt.Before(afterTimeOfCreation)) + + // get another connection from the pool and make sure it's also set to the correct db while + // the other connection is still open + conn2, err := db.Conn(context.Background()) + suite.Require().NoError(err) + var dbNameFromConn2 string + suite.Require().NoError(conn2.QueryRowContext(context.Background(), "SELECT current_database()").Scan(&dbNameFromConn2)) + suite.Equal(dbName, dbNameFromConn2) + + suite.Require().NoError(conn1.Close()) + suite.Require().NoError(conn2.Close()) + + // drop database + suite.Require().NoError(db.Close()) + suite.Require().NoError(dropper(context.Background())) + + // expect an error when attempting to query the database, since it should be dropped. + // when a db pool is opened, it has no connections. + // a query is needed in order to find if the database still exists. + conn, err := suite.getConnPoolForDb(dbName) + suite.Require().NoError(err) + suite.Require().ErrorContains(conn.QueryRowContext(context.Background(), metadataQuery).Scan(&createdAt), "SQLSTATE 3D000") + suite.True(createdAt.Before(afterTimeOfCreation)) +} + +func (suite *onInstanceTempDbFactorySuite) TestCreate_ConnectsToWrongDatabase() { + factory, err := NewOnInstanceFactory(context.Background(), func(ctx context.Context, dbName string) (*sql.DB, error) { + return suite.getConnPoolForDb("postgres") + }) + suite.Require().NoError(err) + defer func(factory Factory) { + suite.Require().NoError(factory.Close()) + }(factory) + + _, _, err = factory.Create(context.Background()) + suite.ErrorContains(err, "connection pool is on") +} + +func (suite *onInstanceTempDbFactorySuite) TestDropTempDB_CannotDropNonTempDb() { + factory := suite.mustBuildFactory() + defer func(factory Factory) { + suite.Require().NoError(factory.Close()) + }(factory) + + suite.ErrorContains(factory.(*onInstanceFactory).dropTempDatabase(context.Background(), "some_db"), "drop non-temporary database") +} + +func TestOnInstanceFactorySuite(t *testing.T) { + suite.Run(t, new(onInstanceTempDbFactorySuite)) +} diff --git a/scripts/codegen.sh b/scripts/codegen.sh new file mode 100755 index 0000000..95af2f3 --- /dev/null +++ b/scripts/codegen.sh @@ -0,0 +1,4 @@ +#!/bin/sh + +docker build -t pg-schema-diff-code-gen-runner -f ./build/Dockerfile.codegen . +docker run --rm -v $(pwd):/pg-schema-diff -w /pg-schema-diff pg-schema-diff-code-gen-runner