diff --git a/internal/migration_acceptance_tests/column_cases_test.go b/internal/migration_acceptance_tests/column_cases_test.go index 3425c86..5cf0809 100644 --- a/internal/migration_acceptance_tests/column_cases_test.go +++ b/internal/migration_acceptance_tests/column_cases_test.go @@ -145,6 +145,42 @@ var columnAcceptanceTestCases = []acceptanceTestCase{ ) `}, }, + { + name: "Add identity column - always no cycle", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + identity_always BIGINT GENERATED ALWAYS AS IDENTITY ( MINVALUE 2 MAXVALUE 9 START 3 INCREMENT 4 CACHE 5 NO CYCLE ) + ); + `, + }, + }, + { + name: "Add identity column - default cycle", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY, + identity_always BIGINT GENERATED BY DEFAULT AS IDENTITY ( MINVALUE 2 MAXVALUE 9 START 3 INCREMENT 4 CACHE 5 CYCLE ) + ); + `, + }, + }, { name: "Delete one column", oldSchemaDDL: []string{ @@ -925,6 +961,210 @@ var columnAcceptanceTestCases = []acceptanceTestCase{ diff.MigrationHazardTypeImpactsDatabasePerformance, }, }, + { + name: "Remove identity from column", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + identity_always BIGINT GENERATED ALWAYS AS IDENTITY + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + identity_always BIGINT + ); + `, + }, + }, + { + name: "Add identity to column", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + identity_always BIGINT + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + identity_always BIGINT GENERATED ALWAYS AS IDENTITY + ); + `, + }, + }, + { + name: "Alter identity type - to by default", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + identity_always BIGINT GENERATED ALWAYS AS IDENTITY ( MINVALUE 2 MAXVALUE 9 START 3 INCREMENT 4 CACHE 5 NO CYCLE ) + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + identity_always BIGINT GENERATED BY DEFAULT AS IDENTITY ( MINVALUE 1 MAXVALUE 9 START 3 INCREMENT 4 CACHE 5 NO CYCLE ) + ); + `, + }, + }, + { + name: "Alter identity type - to always", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + identity_always BIGINT GENERATED BY DEFAULT AS IDENTITY ( MINVALUE 2 MAXVALUE 9 START 3 INCREMENT 4 CACHE 5 NO CYCLE ) + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + identity_always BIGINT GENERATED ALWAYS AS IDENTITY ( MINVALUE 1 MAXVALUE 9 START 3 INCREMENT 4 CACHE 5 NO CYCLE ) + ); + `, + }, + }, + { + name: "Alter identity minvalue", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + identity_always BIGINT GENERATED ALWAYS AS IDENTITY ( MINVALUE 2 MAXVALUE 9 START 3 INCREMENT 4 CACHE 5 NO CYCLE ) + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + identity_always BIGINT GENERATED ALWAYS AS IDENTITY ( MINVALUE 1 MAXVALUE 9 START 3 INCREMENT 4 CACHE 5 NO CYCLE ) + ); + `, + }, + }, + { + name: "Alter identity maxvalue", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + identity_always BIGINT GENERATED ALWAYS AS IDENTITY ( MINVALUE 2 MAXVALUE 9 START 3 INCREMENT 4 CACHE 5 NO CYCLE ) + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + identity_always BIGINT GENERATED ALWAYS AS IDENTITY ( MINVALUE 2 MAXVALUE 90 START 3 INCREMENT 4 CACHE 5 NO CYCLE ) + ); + `, + }, + }, + { + name: "Alter identity start", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + identity_always BIGINT GENERATED ALWAYS AS IDENTITY ( MINVALUE 2 MAXVALUE 9 START 3 INCREMENT 4 CACHE 5 NO CYCLE ) + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + identity_always BIGINT GENERATED ALWAYS AS IDENTITY ( MINVALUE 2 MAXVALUE 9 START 6 INCREMENT 4 CACHE 5 NO CYCLE ) + ); + `, + }, + }, + { + name: "Alter identity increment", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + identity_always BIGINT GENERATED ALWAYS AS IDENTITY ( MINVALUE 2 MAXVALUE 9 START 3 INCREMENT 4 CACHE 5 NO CYCLE ) + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + identity_always BIGINT GENERATED ALWAYS AS IDENTITY ( MINVALUE 2 MAXVALUE 9 START 3 INCREMENT 6 CACHE 5 NO CYCLE ) + ); + `, + }, + }, + { + name: "Alter identity cache", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + identity_always BIGINT GENERATED ALWAYS AS IDENTITY ( MINVALUE 2 MAXVALUE 9 START 3 INCREMENT 4 CACHE 5 NO CYCLE ) + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + identity_always BIGINT GENERATED ALWAYS AS IDENTITY ( MINVALUE 2 MAXVALUE 9 START 3 INCREMENT 4 CACHE 6 NO CYCLE ) + ); + `, + }, + }, + { + name: "Alter identity cycle - to cycle", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + identity_always BIGINT GENERATED ALWAYS AS IDENTITY ( MINVALUE 2 MAXVALUE 9 START 3 INCREMENT 4 CACHE 5 NO CYCLE ) + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + identity_always BIGINT GENERATED ALWAYS AS IDENTITY ( MINVALUE 2 MAXVALUE 9 START 3 INCREMENT 4 CACHE 5 CYCLE ) + ); + `, + }, + }, + { + name: "Alter identity cycle - to no cycle", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + identity_always BIGINT GENERATED ALWAYS AS IDENTITY ( MINVALUE 2 MAXVALUE 9 START 3 INCREMENT 4 CACHE 5 CYCLE ) + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + identity_always BIGINT GENERATED ALWAYS AS IDENTITY ( MINVALUE 2 MAXVALUE 9 START 3 INCREMENT 4 CACHE 5 NO CYCLE ) + ); + `, + }, + }, + { + name: "Alter all identity properties (from always to default, from no cycle to cycle)", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + identity_always BIGINT GENERATED ALWAYS AS IDENTITY ( MINVALUE 2 MAXVALUE 9 START 3 INCREMENT 4 CACHE 5 NO CYCLE ) + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + identity_always BIGINT GENERATED BY DEFAULT AS IDENTITY ( MINVALUE 1 MAXVALUE 90 START 30 INCREMENT 40 CACHE 50 CYCLE ) + ); + `, + }, + }, } func (suite *acceptanceTestSuite) TestColumnTestCases() { diff --git a/internal/migration_acceptance_tests/partitioned_table_cases_test.go b/internal/migration_acceptance_tests/partitioned_table_cases_test.go index 0ef22d3..556142b 100644 --- a/internal/migration_acceptance_tests/partitioned_table_cases_test.go +++ b/internal/migration_acceptance_tests/partitioned_table_cases_test.go @@ -14,6 +14,7 @@ var partitionedTableAcceptanceTestCases = []acceptanceTestCase{ foo VARCHAR(255), bar TEXT COLLATE "POSIX", fizz SERIAL, + identity BIGINT GENERATED ALWAYS AS IDENTITY ( MINVALUE 2 MAXVALUE 9 START 3 INCREMENT 4 CACHE 5 NO CYCLE ), CHECK ( fizz > 0 ), PRIMARY KEY (foo, id), UNIQUE (foo, bar) @@ -58,6 +59,7 @@ var partitionedTableAcceptanceTestCases = []acceptanceTestCase{ foo VARCHAR(255), bar TEXT COLLATE "POSIX", fizz SERIAL, + identity BIGINT GENERATED ALWAYS AS IDENTITY ( MINVALUE 2 MAXVALUE 9 START 3 INCREMENT 4 CACHE 5 NO CYCLE ), CHECK ( fizz > 0 ), PRIMARY KEY (foo, id), UNIQUE (foo, bar) @@ -114,6 +116,7 @@ var partitionedTableAcceptanceTestCases = []acceptanceTestCase{ foo VARCHAR(255), bar TEXT COLLATE "POSIX" NOT NULL DEFAULT 'some default', fizz SERIAL, + identity BIGINT GENERATED ALWAYS AS IDENTITY ( MINVALUE 2 MAXVALUE 9 START 3 INCREMENT 4 CACHE 5 NO CYCLE ), CHECK ( fizz > 0 ), PRIMARY KEY (foo, id), UNIQUE (foo, bar) @@ -140,40 +143,6 @@ var partitionedTableAcceptanceTestCases = []acceptanceTestCase{ CREATE UNIQUE INDEX foobar_2_local_unique_idx ON schema_2.foobar_2(foo); `, }, - - expectedDBSchemaDDL: []string{` - CREATE SCHEMA schema_1; - CREATE TABLE schema_1."Foobar"( - id INT, - foo VARCHAR(255), - bar TEXT COLLATE "POSIX" NOT NULL DEFAULT 'some default', - fizz SERIAL, - CHECK ( fizz > 0 ), - PRIMARY KEY (foo, id), - UNIQUE (foo, bar) - ) PARTITION BY LIST (foo); - ALTER TABLE schema_1."Foobar" REPLICA IDENTITY FULL; - - ALTER TABLE schema_1."Foobar" ENABLE ROW LEVEL SECURITY; - ALTER TABLE schema_1."Foobar" FORCE ROW LEVEL SECURITY; - - -- partitions - CREATE SCHEMA schema_2; - CREATE TABLE schema_2."FOOBAR_1" PARTITION OF schema_1."Foobar"( - foo NOT NULL, - bar NOT NULL - ) FOR VALUES IN ('foo_1'); - ALTER TABLE schema_2."FOOBAR_1" REPLICA IDENTITY NOTHING ; - CREATE TABLE schema_2.foobar_2 PARTITION OF schema_1."Foobar" FOR VALUES IN ('foo_2'); - ALTER TABLE schema_2.foobar_2 REPLICA IDENTITY FULL; - CREATE TABLE schema_2.foobar_3 PARTITION OF schema_1."Foobar" FOR VALUES IN ('foo_3'); - -- partitioned indexes - CREATE UNIQUE INDEX foobar_unique_idx ON schema_1."Foobar"(foo, fizz); - -- local indexes - CREATE INDEX foobar_1_local_idx ON schema_2."FOOBAR_1"(foo); - CREATE UNIQUE INDEX foobar_2_local_unique_idx ON schema_2.foobar_2(foo); - `, - }, }, { name: "Create partitioned table with local primary keys and RLS enabled locally", diff --git a/internal/migration_acceptance_tests/schema_cases_test.go b/internal/migration_acceptance_tests/schema_cases_test.go index 9a85cf3..5bb0eb8 100644 --- a/internal/migration_acceptance_tests/schema_cases_test.go +++ b/internal/migration_acceptance_tests/schema_cases_test.go @@ -49,6 +49,7 @@ var schemaAcceptanceTests = []acceptanceTestCase{ bar SERIAL NOT NULL, fizz TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP, color schema_1.color DEFAULT 'green', + quux BIGINT GENERATED ALWAYS AS IDENTITY ( MINVALUE 2 MAXVALUE 9 START 3 INCREMENT 4 CACHE 5 NO CYCLE ), PRIMARY KEY (foo, id), UNIQUE (foo, bar) ) PARTITION BY LIST(foo); @@ -123,6 +124,7 @@ var schemaAcceptanceTests = []acceptanceTestCase{ bar SERIAL NOT NULL, fizz TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP, color schema_1.color DEFAULT 'green', + quux BIGINT GENERATED ALWAYS AS IDENTITY ( MINVALUE 2 MAXVALUE 9 START 3 INCREMENT 4 CACHE 5 NO CYCLE ), PRIMARY KEY (foo, id), UNIQUE (foo, bar) ) PARTITION BY LIST(foo); @@ -211,6 +213,7 @@ var schemaAcceptanceTests = []acceptanceTestCase{ foo VARCHAR(255) DEFAULT 'some default' NOT NULL, fizz TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP, color schema_1.color DEFAULT 'green', + quux BIGINT GENERATED ALWAYS AS IDENTITY ( MINVALUE 2 MAXVALUE 9 START 3 INCREMENT 4 CACHE 5 NO CYCLE ), UNIQUE (foo, bar) ); CREATE INDEX foobar_normal_idx ON foobar USING hash (fizz); @@ -292,6 +295,7 @@ var schemaAcceptanceTests = []acceptanceTestCase{ new_color new_color DEFAULT 'cyan', new_bar SMALLSERIAL NOT NULL, new_foo VARCHAR(255) DEFAULT '' NOT NULL CHECK ( new_foo IS NOT NULL), + new_quux BIGINT GENERATED BY DEFAULT AS IDENTITY ( MINVALUE 20 MAXVALUE 90 START 30 INCREMENT 40 CACHE 50 CYCLE ), UNIQUE (new_foo, new_bar) ); ALTER TABLE "New_table" ADD CONSTRAINT "new_fzz_check" CHECK ( new_fizz < CURRENT_TIMESTAMP - interval '1 month' ) NO INHERIT NOT VALID; diff --git a/internal/migration_acceptance_tests/table_cases_test.go b/internal/migration_acceptance_tests/table_cases_test.go index 8dc89bf..9b2e5f0 100644 --- a/internal/migration_acceptance_tests/table_cases_test.go +++ b/internal/migration_acceptance_tests/table_cases_test.go @@ -14,7 +14,9 @@ var tableAcceptanceTestCases = []acceptanceTestCase{ foo VARCHAR(255) COLLATE "POSIX" DEFAULT '' NOT NULL, bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP, fizz SERIAL NOT NULL UNIQUE , - buzz REAL CHECK (buzz IS NOT NULL) + buzz REAL CHECK (buzz IS NOT NULL), + identity_always BIGINT GENERATED ALWAYS AS IDENTITY ( MINVALUE 2 MAXVALUE 9 START 3 INCREMENT 4 CACHE 5 NO CYCLE ), + identity_default BIGINT GENERATED BY DEFAULT AS IDENTITY ( MINVALUE 20 MAXVALUE 90 START 30 INCREMENT 40 CACHE 50 CYCLE ) ); ALTER TABLE foobar REPLICA IDENTITY FULL; ALTER TABLE foobar ENABLE ROW LEVEL SECURITY; @@ -39,7 +41,9 @@ var tableAcceptanceTestCases = []acceptanceTestCase{ foo VARCHAR(255) COLLATE "POSIX" DEFAULT '' NOT NULL, bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP, fizz SERIAL NOT NULL UNIQUE, - buzz REAL CHECK (buzz IS NOT NULL) + buzz REAL CHECK (buzz IS NOT NULL), + identity_always BIGINT GENERATED ALWAYS AS IDENTITY ( MINVALUE 2 MAXVALUE 9 START 3 INCREMENT 4 CACHE 5 NO CYCLE ), + identity_default BIGINT GENERATED BY DEFAULT AS IDENTITY ( MINVALUE 20 MAXVALUE 90 START 30 INCREMENT 40 CACHE 50 CYCLE ) ); ALTER TABLE foobar REPLICA IDENTITY FULL; ALTER TABLE foobar ENABLE ROW LEVEL SECURITY; @@ -70,7 +74,9 @@ var tableAcceptanceTestCases = []acceptanceTestCase{ foo VARCHAR(255) COLLATE "POSIX" DEFAULT '' NOT NULL, bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP, fizz SERIAL NOT NULL UNIQUE, - buzz REAL CHECK (buzz IS NOT NULL) + buzz REAL CHECK (buzz IS NOT NULL), + identity_always BIGINT GENERATED ALWAYS AS IDENTITY ( MINVALUE 2 MAXVALUE 9 START 3 INCREMENT 4 CACHE 5 NO CYCLE ), + identity_default BIGINT GENERATED BY DEFAULT AS IDENTITY ( MINVALUE 20 MAXVALUE 90 START 30 INCREMENT 40 CACHE 50 CYCLE ) ); ALTER TABLE foobar REPLICA IDENTITY FULL; ALTER TABLE foobar ENABLE ROW LEVEL SECURITY; diff --git a/internal/queries/queries.sql b/internal/queries/queries.sql index aa50615..8a1ea69 100644 --- a/internal/queries/queries.sql +++ b/internal/queries/queries.sql @@ -54,6 +54,30 @@ WHERE AND (c.relkind = 'r' OR c.relkind = 'p'); -- name: GetColumnsForTable :many +WITH identity_col_seq AS ( + SELECT + depend.refobjid AS owner_relid, + depend.refobjsubid AS owner_attnum, + pg_seq.seqstart, + pg_seq.seqincrement, + pg_seq.seqmax, + pg_seq.seqmin, + pg_seq.seqcache, + pg_seq.seqcycle + FROM pg_catalog.pg_sequence AS pg_seq + INNER JOIN pg_catalog.pg_depend AS depend + ON + depend.classid = 'pg_class'::REGCLASS + AND pg_seq.seqrelid = depend.objid + AND depend.refclassid = 'pg_class'::REGCLASS + AND depend.deptype = 'i' + INNER JOIN pg_catalog.pg_attribute AS owner_attr + ON + depend.refobjid = owner_attr.attrelid + AND depend.refobjsubid = owner_attr.attnum + WHERE owner_attr.attidentity != '' +) + SELECT a.attname::TEXT AS column_name, COALESCE(coll.collname, '')::TEXT AS collation_name, @@ -63,6 +87,13 @@ SELECT )::TEXT AS default_value, a.attnotnull AS is_not_null, a.attlen AS column_size, + a.attidentity::TEXT AS identity_type, + identity_col_seq.seqstart AS start_value, + identity_col_seq.seqincrement AS increment_value, + identity_col_seq.seqmax AS max_value, + identity_col_seq.seqmin AS min_value, + identity_col_seq.seqcache AS cache_size, + identity_col_seq.seqcycle AS is_cycle, pg_catalog.format_type(a.atttypid, a.atttypmod) AS column_type FROM pg_catalog.pg_attribute AS a LEFT JOIN @@ -72,6 +103,11 @@ LEFT JOIN pg_catalog.pg_collation AS coll ON coll.oid = a.attcollation LEFT JOIN pg_catalog.pg_namespace AS collation_namespace ON collation_namespace.oid = coll.collnamespace +LEFT JOIN + identity_col_seq + ON + identity_col_seq.owner_relid = a.attrelid + AND identity_col_seq.owner_attnum = a.attnum WHERE a.attrelid = $1 AND a.attnum > 0 @@ -287,7 +323,7 @@ LEFT JOIN pg_catalog.pg_depend AS depend depend.classid = 'pg_class'::REGCLASS AND pg_seq.seqrelid = depend.objid AND depend.refclassid = 'pg_class'::REGCLASS - AND depend.deptype = 'a' + AND depend.deptype IN ('a', 'i') LEFT JOIN pg_catalog.pg_attribute AS owner_attr ON depend.refobjid = owner_attr.attrelid @@ -300,6 +336,9 @@ WHERE seq_ns.nspname NOT IN ('pg_catalog', 'information_schema') AND seq_ns.nspname !~ '^pg_toast' AND seq_ns.nspname !~ '^pg_temp' + -- Exclude sequences owned by identity columns. + -- These manifest as internal dependency on the column + AND (depend.deptype IS NULL OR depend.deptype != 'i') -- Exclude sequences belonging to extensions AND NOT EXISTS ( SELECT ext_depend.objid diff --git a/internal/queries/queries.sql.go b/internal/queries/queries.sql.go index 53895b9..746a38f 100644 --- a/internal/queries/queries.sql.go +++ b/internal/queries/queries.sql.go @@ -7,6 +7,7 @@ package queries import ( "context" + "database/sql" "github.com/lib/pq" ) @@ -88,6 +89,30 @@ func (q *Queries) GetCheckConstraints(ctx context.Context) ([]GetCheckConstraint } const getColumnsForTable = `-- name: GetColumnsForTable :many +WITH identity_col_seq AS ( + SELECT + depend.refobjid AS owner_relid, + depend.refobjsubid AS owner_attnum, + pg_seq.seqstart, + pg_seq.seqincrement, + pg_seq.seqmax, + pg_seq.seqmin, + pg_seq.seqcache, + pg_seq.seqcycle + FROM pg_catalog.pg_sequence AS pg_seq + INNER JOIN pg_catalog.pg_depend AS depend + ON + depend.classid = 'pg_class'::REGCLASS + AND pg_seq.seqrelid = depend.objid + AND depend.refclassid = 'pg_class'::REGCLASS + AND depend.deptype = 'i' + INNER JOIN pg_catalog.pg_attribute AS owner_attr + ON + depend.refobjid = owner_attr.attrelid + AND depend.refobjsubid = owner_attr.attnum + WHERE owner_attr.attidentity != '' +) + SELECT a.attname::TEXT AS column_name, COALESCE(coll.collname, '')::TEXT AS collation_name, @@ -97,6 +122,13 @@ SELECT )::TEXT AS default_value, a.attnotnull AS is_not_null, a.attlen AS column_size, + a.attidentity::TEXT AS identity_type, + identity_col_seq.seqstart AS start_value, + identity_col_seq.seqincrement AS increment_value, + identity_col_seq.seqmax AS max_value, + identity_col_seq.seqmin AS min_value, + identity_col_seq.seqcache AS cache_size, + identity_col_seq.seqcycle AS is_cycle, pg_catalog.format_type(a.atttypid, a.atttypmod) AS column_type FROM pg_catalog.pg_attribute AS a LEFT JOIN @@ -106,6 +138,11 @@ LEFT JOIN pg_catalog.pg_collation AS coll ON coll.oid = a.attcollation LEFT JOIN pg_catalog.pg_namespace AS collation_namespace ON collation_namespace.oid = coll.collnamespace +LEFT JOIN + identity_col_seq + ON + identity_col_seq.owner_relid = a.attrelid + AND identity_col_seq.owner_attnum = a.attnum WHERE a.attrelid = $1 AND a.attnum > 0 @@ -120,6 +157,13 @@ type GetColumnsForTableRow struct { DefaultValue string IsNotNull bool ColumnSize int16 + IdentityType string + StartValue sql.NullInt64 + IncrementValue sql.NullInt64 + MaxValue sql.NullInt64 + MinValue sql.NullInt64 + CacheSize sql.NullInt64 + IsCycle sql.NullBool ColumnType string } @@ -139,6 +183,13 @@ func (q *Queries) GetColumnsForTable(ctx context.Context, attrelid interface{}) &i.DefaultValue, &i.IsNotNull, &i.ColumnSize, + &i.IdentityType, + &i.StartValue, + &i.IncrementValue, + &i.MaxValue, + &i.MinValue, + &i.CacheSize, + &i.IsCycle, &i.ColumnType, ); err != nil { return nil, err @@ -727,7 +778,7 @@ LEFT JOIN pg_catalog.pg_depend AS depend depend.classid = 'pg_class'::REGCLASS AND pg_seq.seqrelid = depend.objid AND depend.refclassid = 'pg_class'::REGCLASS - AND depend.deptype = 'a' + AND depend.deptype IN ('a', 'i') LEFT JOIN pg_catalog.pg_attribute AS owner_attr ON depend.refobjid = owner_attr.attrelid @@ -740,6 +791,9 @@ WHERE seq_ns.nspname NOT IN ('pg_catalog', 'information_schema') AND seq_ns.nspname !~ '^pg_toast' AND seq_ns.nspname !~ '^pg_temp' + -- Exclude sequences owned by identity columns. + -- These manifest as internal dependency on the column + AND (depend.deptype IS NULL OR depend.deptype != 'i') -- Exclude sequences belonging to extensions AND NOT EXISTS ( SELECT ext_depend.objid diff --git a/internal/schema/schema.go b/internal/schema/schema.go index d58bef1..68d66d8 100644 --- a/internal/schema/schema.go +++ b/internal/schema/schema.go @@ -199,22 +199,41 @@ func (t Table) IsPartition() bool { return t.ParentTable != nil } -type Column struct { - Name string - Type string - Collation SchemaQualifiedName - // If the column has a default value, this will be a SQL string representing that value. - // Examples: - // ''::text - // CURRENT_TIMESTAMP - // If empty, indicates that there is no default value. - Default string - IsNullable bool - - // Size is the number of bytes required to store the value. - // It is used for data-packing purposes - Size int // -} +type ColumnIdentityType string + +const ( + ColumnIdentityTypeAlways = "a" + ColumnIdentityTypeByDefault = "d" +) + +type ( + ColumnIdentity struct { + Type ColumnIdentityType + MinValue int64 + MaxValue int64 + StartValue int64 + Increment int64 + CacheSize int64 + Cycle bool + } + + Column struct { + Name string + Type string + Collation SchemaQualifiedName + // If the column has a default value, this will be a SQL string representing that value. + // Examples: + // ''::text + // CURRENT_TIMESTAMP + // If empty, indicates that there is no default value. + Default string + IsNullable bool + // Size is the number of bytes required to store the value. + // It is used for data-packing purposes + Size int + Identity *ColumnIdentity + } +) func (c Column) GetName() string { return c.Name @@ -822,6 +841,19 @@ func (s *schemaFetcher) buildTable( } } + var identity *ColumnIdentity + if len(column.IdentityType) > 0 { + identity = &ColumnIdentity{ + Type: ColumnIdentityType(column.IdentityType), + StartValue: column.StartValue.Int64, + Increment: column.IncrementValue.Int64, + MaxValue: column.MaxValue.Int64, + MinValue: column.MinValue.Int64, + CacheSize: column.CacheSize.Int64, + Cycle: column.IsCycle.Bool, + } + } + columns = append(columns, Column{ Name: column.ColumnName, Type: column.ColumnType, @@ -832,8 +864,9 @@ func (s *schemaFetcher) buildTable( // ''::text // CURRENT_TIMESTAMP // If empty, indicates that there is no default value. - Default: column.DefaultValue, - Size: int(column.ColumnSize), + Default: column.DefaultValue, + Size: int(column.ColumnSize), + Identity: identity, }) } diff --git a/internal/schema/schema_test.go b/internal/schema/schema_test.go index 6c54ad8..4d0b077 100644 --- a/internal/schema/schema_test.go +++ b/internal/schema/schema_test.go @@ -191,7 +191,7 @@ var ( TO PUBLIC USING (version > 0); `}, - expectedHash: "ffcf26204e89f536", + expectedHash: "4f6a01ac1a078624", expectedSchema: Schema{ NamedSchemas: []NamedSchema{ {Name: "public"}, @@ -494,7 +494,7 @@ var ( ALTER TABLE foo_fk_1 ADD CONSTRAINT foo_fk_1_fk FOREIGN KEY (author, content) REFERENCES foo_1 (author, content) NOT VALID; `}, - expectedHash: "481b62a68155716d", + expectedHash: "14fc890b05a1fa7b", expectedSchema: Schema{ NamedSchemas: []NamedSchema{ {Name: "public"}, @@ -834,20 +834,22 @@ var ( }, }, { - name: "Common Data Types", + name: "Common columns", ddl: []string{` CREATE TABLE foo ( - "varchar" VARCHAR(128) NOT NULL DEFAULT '', - "text" TEXT NOT NULL DEFAULT '', - "bool" BOOLEAN NOT NULL DEFAULT False, - "blob" BYTEA NOT NULL DEFAULT '', - "smallint" SMALLINT NOT NULL DEFAULT 0, - "real" REAL NOT NULL DEFAULT 0.0, - "double_precision" DOUBLE PRECISION NOT NULL DEFAULT 0.0, - "integer" INTEGER NOT NULL DEFAULT 0, - "big_integer" BIGINT NOT NULL DEFAULT 0, - "decimal" DECIMAL(65, 10) NOT NULL DEFAULT 0.0, - "serial" SERIAL NOT NULL + varchar VARCHAR(128) NOT NULL DEFAULT '', + text TEXT NOT NULL DEFAULT '', + bool BOOLEAN NOT NULL DEFAULT False, + blob BYTEA NOT NULL DEFAULT '', + smallint SMALLINT NOT NULL DEFAULT 0, + real REAL NOT NULL DEFAULT 0.0, + double_precision DOUBLE PRECISION NOT NULL DEFAULT 0.0, + integer INTEGER NOT NULL DEFAULT 0, + big_integer BIGINT NOT NULL DEFAULT 0, + decimal DECIMAL(65, 10) NOT NULL DEFAULT 0.0, + serial SERIAL NOT NULL, + identity_always BIGINT GENERATED ALWAYS AS IDENTITY ( MINVALUE 2 MAXVALUE 9 START 3 INCREMENT 4 CACHE 5 NO CYCLE ), + identity_default BIGINT GENERATED BY DEFAULT AS IDENTITY ( MINVALUE 20 MAXVALUE 90 START 30 INCREMENT 40 CACHE 50 CYCLE ) ); `}, expectedSchema: Schema{ @@ -869,6 +871,28 @@ var ( {Name: "big_integer", Type: "bigint", Default: "0", Size: 8}, {Name: "decimal", Type: "numeric(65,10)", Default: "0.0", Size: -1}, {Name: "serial", Type: "integer", Collation: SchemaQualifiedName{}, Default: "nextval('foo_serial_seq'::regclass)", IsNullable: false, Size: 4}, + {Name: "identity_always", Type: "bigint", Size: 8, + Identity: &ColumnIdentity{ + Type: ColumnIdentityTypeAlways, + MinValue: 2, + MaxValue: 9, + StartValue: 3, + Increment: 4, + CacheSize: 5, + Cycle: false, + }, + }, + {Name: "identity_default", Type: "bigint", Size: 8, + Identity: &ColumnIdentity{ + Type: ColumnIdentityTypeByDefault, + MinValue: 20, + MaxValue: 90, + StartValue: 30, + Increment: 40, + CacheSize: 50, + Cycle: true, + }, + }, }, CheckConstraints: nil, ReplicaIdentity: ReplicaIdentityDefault, diff --git a/pkg/diff/sql_generator.go b/pkg/diff/sql_generator.go index 0bfe81c..abd9341 100644 --- a/pkg/diff/sql_generator.go +++ b/pkg/diff/sql_generator.go @@ -729,7 +729,11 @@ func (t *tableSQLVertexGenerator) Add(table schema.Table) ([]Statement, error) { var columnDefs []string for _, column := range table.Columns { - columnDefs = append(columnDefs, "\t"+buildColumnDefinition(column)) + columnDef, err := buildColumnDefinition(column) + if err != nil { + return nil, fmt.Errorf("building column definition: %w", err) + } + columnDefs = append(columnDefs, "\t"+columnDef) } createTableSb := strings.Builder{} createTableSb.WriteString(fmt.Sprintf("CREATE TABLE %s (\n%s\n)", @@ -883,8 +887,8 @@ func (t *tableSQLVertexGenerator) alterBaseTable(diff tableDiff) ([]Statement, e tempCCs = append(tempCCs, tempCC) } - columnSQLVertexGenerator := columnSQLVertexGenerator{tableName: diff.new.SchemaQualifiedName} - columnGraph, err := diff.columnsDiff.resolveToSQLGraph(&columnSQLVertexGenerator) + columnSQLVertexGenerator := newColumnSQLVertexGenerator(diff.new.SchemaQualifiedName) + columnGraph, err := diff.columnsDiff.resolveToSQLGraph(columnSQLVertexGenerator) if err != nil { return nil, fmt.Errorf("resolving index diff: %w", err) } @@ -1125,9 +1129,17 @@ type columnSQLVertexGenerator struct { tableName schema.SchemaQualifiedName } +func newColumnSQLVertexGenerator(tableName schema.SchemaQualifiedName) *columnSQLVertexGenerator { + return &columnSQLVertexGenerator{tableName: tableName} +} + func (csg *columnSQLVertexGenerator) Add(column schema.Column) ([]Statement, error) { + columnDef, err := buildColumnDefinition(column) + if err != nil { + return nil, fmt.Errorf("building column definition: %w", err) + } return []Statement{{ - DDL: fmt.Sprintf("%s ADD COLUMN %s", alterTablePrefix(csg.tableName), buildColumnDefinition(column)), + DDL: fmt.Sprintf("%s ADD COLUMN %s", alterTablePrefix(csg.tableName), columnDef), Timeout: statementTimeoutDefault, LockTimeout: lockTimeoutDefault, }}, nil @@ -1155,20 +1167,30 @@ func (csg *columnSQLVertexGenerator) Alter(diff columnDiff) ([]Statement, error) var stmts []Statement alterColumnPrefix := fmt.Sprintf("%s ALTER COLUMN %s", alterTablePrefix(csg.tableName), schema.EscapeIdentifier(newColumn.Name)) - if oldColumn.IsNullable != newColumn.IsNullable { - if newColumn.IsNullable { - stmts = append(stmts, Statement{ - DDL: fmt.Sprintf("%s DROP NOT NULL", alterColumnPrefix), - Timeout: statementTimeoutDefault, - LockTimeout: lockTimeoutDefault, - }) - } else { - stmts = append(stmts, Statement{ - DDL: fmt.Sprintf("%s SET NOT NULL", alterColumnPrefix), - Timeout: statementTimeoutDefault, - LockTimeout: lockTimeoutDefault, - }) - } + // Adding a "NOT NULL" constraint must come before updating a column to be an identity column, otherwise + // the add statement will fail because a column must be non-nullable to become an identity column. + if oldColumn.IsNullable != newColumn.IsNullable && !newColumn.IsNullable { + stmts = append(stmts, Statement{ + DDL: fmt.Sprintf("%s SET NOT NULL", alterColumnPrefix), + Timeout: statementTimeoutDefault, + LockTimeout: lockTimeoutDefault, + }) + } + + updateIdentityStmts, err := csg.buildUpdateIdentityStatements(oldColumn, newColumn) + if err != nil { + return nil, fmt.Errorf("building update identity statements: %w", err) + } + stmts = append(stmts, updateIdentityStmts...) + + // Removing a "NOT NULL" constraint must come after updating a column to no longer be an identity column, otherwise + // the "DROP NOT NULL" statement will fail because the column will still be an identity column. + if oldColumn.IsNullable != newColumn.IsNullable && newColumn.IsNullable { + stmts = append(stmts, Statement{ + DDL: fmt.Sprintf("%s DROP NOT NULL", alterColumnPrefix), + Timeout: statementTimeoutDefault, + LockTimeout: lockTimeoutDefault, + }) } if len(oldColumn.Default) > 0 && len(newColumn.Default) == 0 { @@ -1186,8 +1208,7 @@ func (csg *columnSQLVertexGenerator) Alter(diff columnDiff) ([]Statement, error) stmts = append(stmts, []Statement{ csg.generateTypeTransformationStatement( - alterColumnPrefix, - schema.EscapeIdentifier(newColumn.Name), + diff.new, oldColumn.Type, newColumn.Type, newColumn.Collation, @@ -1225,8 +1246,7 @@ func (csg *columnSQLVertexGenerator) Alter(diff columnDiff) ([]Statement, error) } func (csg *columnSQLVertexGenerator) generateTypeTransformationStatement( - prefix string, - name string, + col schema.Column, oldType string, newType string, newTypeCollation schema.SchemaQualifiedName, @@ -1235,9 +1255,9 @@ func (csg *columnSQLVertexGenerator) generateTypeTransformationStatement( strings.EqualFold(newType, "timestamp without time zone") { return Statement{ DDL: fmt.Sprintf("%s SET DATA TYPE %s using to_timestamp(%s / 1000)", - prefix, + csg.alterColumnPrefix(col), newType, - name, + schema.EscapeIdentifier(col.Name), ), Timeout: statementTimeoutDefault, LockTimeout: lockTimeoutDefault, @@ -1260,10 +1280,10 @@ func (csg *columnSQLVertexGenerator) generateTypeTransformationStatement( return Statement{ DDL: fmt.Sprintf("%s SET DATA TYPE %s %susing %s::%s", - prefix, + csg.alterColumnPrefix(col), newType, collationModifier, - name, + schema.EscapeIdentifier(col.Name), newType, ), Timeout: statementTimeoutDefault, @@ -1279,6 +1299,78 @@ func (csg *columnSQLVertexGenerator) generateTypeTransformationStatement( } } +func (csg *columnSQLVertexGenerator) buildUpdateIdentityStatements(old, new schema.Column) ([]Statement, error) { + if cmp.Equal(old.Identity, new.Identity) { + return nil, nil + } + + // Drop the old identity + if new.Identity == nil { + // ALTER [ COLUMN ] column_name DROP IDENTITY [ IF EXISTS ] + return []Statement{{ + DDL: fmt.Sprintf("%s DROP IDENTITY", csg.alterColumnPrefix(old)), + Timeout: statementTimeoutDefault, + LockTimeout: lockTimeoutDefault, + }}, nil + } + + // Add the new identity + if old.Identity == nil { + def, err := buildColumnIdentityDefinition(*new.Identity) + if err != nil { + return nil, fmt.Errorf("building column identity definition: %w", err) + } + // ALTER [ COLUMN ] column_name ADD GENERATED { ALWAYS | BY DEFAULT } AS IDENTITY [ ( sequence_options ) ] + return []Statement{{ + DDL: fmt.Sprintf("%s ADD %s", csg.alterColumnPrefix(new), def), + Timeout: statementTimeoutDefault, + LockTimeout: lockTimeoutDefault, + }}, nil + } + + // Alter the existing identity + var modifications []string + if old.Identity.Type != new.Identity.Type { + typeModifier, err := columnIdentityTypeToModifier(new.Identity.Type) + if err != nil { + return nil, fmt.Errorf("column identity type modifier: %w", err) + } + modifications = append(modifications, fmt.Sprintf("\tSET GENERATED %s", typeModifier)) + } + if old.Identity.Increment != new.Identity.Increment { + modifications = append(modifications, fmt.Sprintf("\tSET INCREMENT BY %d", new.Identity.Increment)) + } + if old.Identity.MinValue != new.Identity.MinValue { + modifications = append(modifications, fmt.Sprintf("\tSET MINVALUE %d", new.Identity.MinValue)) + } + if old.Identity.MaxValue != new.Identity.MaxValue { + modifications = append(modifications, fmt.Sprintf("\tSET MAXVALUE %d", new.Identity.MaxValue)) + } + if old.Identity.StartValue != new.Identity.StartValue { + modifications = append(modifications, fmt.Sprintf("\tSET START %d", new.Identity.StartValue)) + } + if old.Identity.CacheSize != new.Identity.CacheSize { + modifications = append(modifications, fmt.Sprintf("\tSET CACHE %d", new.Identity.CacheSize)) + } + if old.Identity.Cycle != new.Identity.Cycle { + cycleModifier := "" + if !new.Identity.Cycle { + cycleModifier = "NO " + } + modifications = append(modifications, fmt.Sprintf("\tSET %sCYCLE", cycleModifier)) + } + // ALTER [ COLUMN ] column_name { SET GENERATED { ALWAYS | BY DEFAULT } | SET sequence_option | RESTART [ [ WITH ] restart ] } [...] + return []Statement{{ + DDL: fmt.Sprintf("%s\n%s", csg.alterColumnPrefix(new), strings.Join(modifications, "\n")), + Timeout: statementTimeoutDefault, + LockTimeout: lockTimeoutDefault, + }}, nil +} + +func (csg *columnSQLVertexGenerator) alterColumnPrefix(col schema.Column) string { + return fmt.Sprintf("%s ALTER COLUMN %s", alterTablePrefix(csg.tableName), schema.EscapeIdentifier(col.Name)) +} + func (csg *columnSQLVertexGenerator) GetSQLVertexId(column schema.Column) string { return buildColumnVertexId(column.Name) } @@ -2630,7 +2722,7 @@ func alterTablePrefix(table schema.SchemaQualifiedName) string { return fmt.Sprintf("ALTER TABLE %s", table.GetFQEscapedName()) } -func buildColumnDefinition(column schema.Column) string { +func buildColumnDefinition(column schema.Column) (string, error) { sb := strings.Builder{} sb.WriteString(fmt.Sprintf("%s %s", schema.EscapeIdentifier(column.Name), column.Type)) if column.IsCollated() { @@ -2642,5 +2734,37 @@ func buildColumnDefinition(column schema.Column) string { if len(column.Default) > 0 { sb.WriteString(fmt.Sprintf(" DEFAULT %s", column.Default)) } - return sb.String() + if column.Identity != nil { + identityDef, err := buildColumnIdentityDefinition(*column.Identity) + if err != nil { + return "", fmt.Errorf("building column identity definition: %w", err) + } + sb.WriteString(" " + identityDef) + } + return sb.String(), nil +} + +func buildColumnIdentityDefinition(identity schema.ColumnIdentity) (string, error) { + typeModifier, err := columnIdentityTypeToModifier(identity.Type) + if err != nil { + return "", fmt.Errorf("column identity type modifier: %w", err) + } + + cycleModifier := "" + if !identity.Cycle { + cycleModifier = "NO " + } + + return fmt.Sprintf("GENERATED %s AS IDENTITY (INCREMENT BY %d MINVALUE %d MAXVALUE %d START WITH %d CACHE %d %sCYCLE)", typeModifier, identity.Increment, identity.MinValue, identity.MaxValue, identity.StartValue, identity.CacheSize, cycleModifier), nil +} + +func columnIdentityTypeToModifier(val schema.ColumnIdentityType) (string, error) { + switch val { + case schema.ColumnIdentityTypeAlways: + return "ALWAYS", nil + case schema.ColumnIdentityTypeByDefault: + return "BY DEFAULT", nil + default: + return "", fmt.Errorf("unknown identity type %q", val) + } }