From 702acb18c5ddf20e2c69280456d06b80f46e73bb Mon Sep 17 00:00:00 2001 From: Anthony Alayo Date: Mon, 3 Feb 2025 20:19:39 -0800 Subject: [PATCH] making names for check constraints optional --- drizzle-kit/src/jsonStatements.ts | 9 ++++- drizzle-kit/src/serializer/pgSchema.ts | 8 ++-- drizzle-kit/src/serializer/pgSerializer.ts | 47 +++++++++++++--------- drizzle-kit/src/snapshotsDiffer.ts | 4 +- drizzle-kit/src/sqlgenerator.ts | 21 ++++++++-- drizzle-kit/tests/pg-checks.test.ts | 47 ++++++++++++++++++++++ drizzle-kit/tests/pg-domains.test.ts | 27 +++++++++++++ drizzle-orm/src/pg-core/checks.ts | 14 +++++-- drizzle-zod/tests/pg-checks.test.ts | 14 +++---- 9 files changed, 149 insertions(+), 42 deletions(-) diff --git a/drizzle-kit/src/jsonStatements.ts b/drizzle-kit/src/jsonStatements.ts index 404442cdb..5283f6e9a 100644 --- a/drizzle-kit/src/jsonStatements.ts +++ b/drizzle-kit/src/jsonStatements.ts @@ -1053,6 +1053,11 @@ export const prepareDomainJson = ( type: 'create' | 'alter' | 'drop', action?: JsonAlterDomainStatement['action'], ): JsonDomainStatement => { + let checkConstraints; + if (domain.checkConstraints) { + checkConstraints = Object.values(domain.checkConstraints); + } + if (type === 'create') { return { type: 'create_domain', @@ -1061,7 +1066,7 @@ export const prepareDomainJson = ( notNull: domain.notNull, defaultValue: domain.defaultValue, baseType: domain.baseType, - checkConstraints: Object.values(domain.checkConstraints), + checkConstraints, }; } else if (type === 'drop') { return { @@ -1078,7 +1083,7 @@ export const prepareDomainJson = ( name: domain.name, schema: domain.schema, defaultValue: domain.defaultValue, - checkConstraints: Object.values(domain.checkConstraints), + checkConstraints, }; }; diff --git a/drizzle-kit/src/serializer/pgSchema.ts b/drizzle-kit/src/serializer/pgSchema.ts index ab414ec8d..bd4997000 100644 --- a/drizzle-kit/src/serializer/pgSchema.ts +++ b/drizzle-kit/src/serializer/pgSchema.ts @@ -29,7 +29,7 @@ const tableV2 = object({ }).strict(); const checkConstraint = object({ - name: string(), + name: string().optional(), value: string(), }).strict(); @@ -39,7 +39,7 @@ const domainSchema = object({ baseType: string(), notNull: boolean().optional(), defaultValue: string().optional(), - checkConstraints: record(string(), checkConstraint).default({}), + checkConstraints: record(string(), checkConstraint).optional(), }).strict(); const enumSchemaV1 = object({ @@ -473,7 +473,7 @@ const domainSquashed = object({ baseType: string(), notNull: boolean().optional(), defaultValue: string().optional(), - checkConstraints: record(string(), string()), + checkConstraints: record(string(), string()).optional(), }).strict(); const tableSquashed = object({ @@ -895,7 +895,7 @@ export const squashPgScheme = ( const mappedDomains = mapValues(json.domains, (domain) => { const squashedDomainChecks = mapValues( - domain.checkConstraints, + domain.checkConstraints ?? {}, (check) => PgSquasher.squashCheck(check), ); diff --git a/drizzle-kit/src/serializer/pgSerializer.ts b/drizzle-kit/src/serializer/pgSerializer.ts index c2169fcf7..5e7234290 100644 --- a/drizzle-kit/src/serializer/pgSerializer.ts +++ b/drizzle-kit/src/serializer/pgSerializer.ts @@ -536,11 +536,18 @@ export const generatePgSnapshot = ( }; }); - checks.forEach((check) => { - const checkName = check.name; + checks.forEach((check, index) => { + const tableKey = `"${schema ?? 'public'}"."${tableName}"`; + + // you can have multiple unnamed checks per table (using the default above) + let defaultCheckName = `${tableName}_check` + let checkName = check.name ?? defaultCheckName; + if (checksInTable[tableKey]?.includes(checkName) && checkName == defaultCheckName) { + checkName += `_${index}`; + } if (typeof checksInTable[`"${schema ?? 'public'}"."${tableName}"`] !== 'undefined') { - if (checksInTable[`"${schema ?? 'public'}"."${tableName}"`].includes(check.name)) { + if (checksInTable[`"${schema ?? 'public'}"."${tableName}"`].includes(checkName)) { console.log( `\n${ withStyle.errorWarning( @@ -564,11 +571,11 @@ export const generatePgSnapshot = ( } checksInTable[`"${schema ?? 'public'}"."${tableName}"`].push(checkName); } else { - checksInTable[`"${schema ?? 'public'}"."${tableName}"`] = [check.name]; + checksInTable[`"${schema ?? 'public'}"."${tableName}"`] = [checkName]; } checksObject[checkName] = { - name: checkName, + name: check.name ?? '', // don't squash and include the default name value: dialect.sqlToQuery(check.value).sql, }; }); @@ -875,15 +882,15 @@ export const generatePgSnapshot = ( // Process check constraints similar to tables const checksObject: Record = {}; - obj.checkConstraints?.forEach((checkConstraint) => { - const checkName = checkConstraint.name; - const checkValue = dialect.sqlToQuery(checkConstraint.value).sql; - + obj.checkConstraints?.forEach((checkConstraint, index) => { // Validate unique constraint names within domain const domainKey = `"${obj.schema ?? 'public'}"."${obj.domainName}"`; - if (checksInTable[domainKey]?.includes(checkName)) { - console.error(`Duplicate check constraint name ${checkName} in domain ${domainKey}`); - process.exit(1); + + // you can have multiple unnamed checks per domain (using the default above) + let defaultCheckName = `${obj.domainName}_check` + let checkName = checkConstraint.name ?? defaultCheckName; + if (checksInTable[domainKey]?.includes(checkName) && checkName == defaultCheckName) { + checkName += `_${index}`; } checksInTable[domainKey] = checksInTable[domainKey] @@ -891,8 +898,8 @@ export const generatePgSnapshot = ( : [checkName]; checksObject[checkName] = { - name: checkName, - value: checkValue, + name: checkConstraint.name ?? '', // don't squash and include the default name + value: dialect.sqlToQuery(checkConstraint.value).sql, }; }); @@ -1138,7 +1145,6 @@ WHERE const schemaName = domain.domain_schema || 'public'; const key = `${schemaName}.${domain.domain_name}`; - // Initialize the domain if it doesn't exist if (!domainsToReturn[key]) { domainsToReturn[key] = { name: domain.domain_name, @@ -1146,12 +1152,15 @@ WHERE baseType: domain.base_type, notNull: domain.not_null, defaultValue: domain.default_value, - checkConstraints: {}, // Now using checkConstraints (plural) as a Record }; } // Add the check constraint if present in this row if (domain.constraint_name && domain.domain_constraint) { + if (!domainsToReturn[key].checkConstraints) { + domainsToReturn[key].checkConstraints = {}; + } + domainsToReturn[key].checkConstraints[domain.constraint_name] = { name: domain.constraint_name, value: domain.domain_constraint, @@ -1330,7 +1339,7 @@ WHERE const tableChecks = await db.query(`SELECT tc.constraint_name, tc.constraint_type, - pg_get_constraintdef(con.oid) AS constraint_definition + pg_get_expr(con.conbin, con.conrelid) AS check_expression FROM information_schema.table_constraints AS tc JOIN pg_constraint AS con @@ -1457,11 +1466,9 @@ WHERE for (const checks of tableChecks) { // CHECK (((email)::text <> 'test@gmail.com'::text)) // Where (email) is column in table - let checkValue: string = checks.constraint_definition; + const checkValue: string = checks.check_expression; const constraintName: string = checks.constraint_name; - checkValue = checkValue.replace(/^CHECK\s*\(\(/, '').replace(/\)\)\s*$/, ''); - checkConstraints[constraintName] = { name: constraintName, value: checkValue, diff --git a/drizzle-kit/src/snapshotsDiffer.ts b/drizzle-kit/src/snapshotsDiffer.ts index e5d72fad0..1bf3ade95 100644 --- a/drizzle-kit/src/snapshotsDiffer.ts +++ b/drizzle-kit/src/snapshotsDiffer.ts @@ -250,7 +250,7 @@ const domainSchema = object({ baseType: string(), notNull: boolean().optional(), defaultValue: string().optional(), - checkConstraints: record(string(), string()).default({}), + checkConstraints: record(string(), string()).optional(), }).strict(); const changedDomainSchema = object({ @@ -259,7 +259,7 @@ const changedDomainSchema = object({ baseType: string(), notNull: boolean().optional(), defaultValue: string().optional(), - checkConstraints: record(string(), string()).default({}), + checkConstraints: record(string(), string()).optional(), }).strict(); const enumSchema = object({ diff --git a/drizzle-kit/src/sqlgenerator.ts b/drizzle-kit/src/sqlgenerator.ts index 6ba0fb8a0..8137e0274 100644 --- a/drizzle-kit/src/sqlgenerator.ts +++ b/drizzle-kit/src/sqlgenerator.ts @@ -476,7 +476,12 @@ class PgCreateTableConvertor extends Convertor { for (const checkConstraint of checkConstraints) { statement += ',\n'; const unsquashedCheck = PgSquasher.unsquashCheck(checkConstraint); - statement += `\tCONSTRAINT "${unsquashedCheck.name}" CHECK (${unsquashedCheck.value})`; + + if(unsquashedCheck.name) { + statement += `\tCONSTRAINT "${unsquashedCheck.name}" CHECK (${unsquashedCheck.value})`; + } else { + statement += `\tCHECK (${unsquashedCheck.value})`; + } } } @@ -1398,7 +1403,12 @@ class CreateDomainConvertor extends DomainConvertor { if (checkConstraints && checkConstraints.length > 0) { for (const checkConstraint of checkConstraints) { const unsquashedCheck = PgSquasher.unsquashCheck(checkConstraint); - statement += ` CONSTRAINT ${unsquashedCheck.name} CHECK (${unsquashedCheck.value})`; + + if (unsquashedCheck.name) { + statement += ` CONSTRAINT ${unsquashedCheck.name} CHECK (${unsquashedCheck.value})`; + } else { + statement += ` CHECK (${unsquashedCheck.value})`; + } } } @@ -1422,7 +1432,12 @@ class AlterDomainConvertor extends DomainConvertor { if (checkConstraints && checkConstraints.length > 0) { for (const checkConstraint of checkConstraints) { const unsquashedCheck = PgSquasher.unsquashCheck(checkConstraint); - statement += ` ADD CONSTRAINT ${unsquashedCheck.name} CHECK (${unsquashedCheck.value})`; + + if (unsquashedCheck.name) { + statement += ` ADD CONSTRAINT ${unsquashedCheck.name} CHECK (${unsquashedCheck.value})`; + } else { + statement += ` ADD CHECK (${unsquashedCheck.value})`; + } } } break; diff --git a/drizzle-kit/tests/pg-checks.test.ts b/drizzle-kit/tests/pg-checks.test.ts index 8033aacef..8902e52ff 100644 --- a/drizzle-kit/tests/pg-checks.test.ts +++ b/drizzle-kit/tests/pg-checks.test.ts @@ -280,3 +280,50 @@ test('create checks with same names', async (t) => { await expect(diffTestSchemas({}, to, [])).rejects.toThrowError(); }); + +test('create unnamed checks', async (t) => { + const to = { + users: pgTable('users', { + id: serial('id').primaryKey(), + age: integer('age'), + }, (table) => ({ + checkConstraint: check(sql`${table.age} > 21`), + })), + }; + + const { sqlStatements, statements } = await diffTestSchemas({}, to, []); + + expect(statements.length).toBe(1); + expect(statements[0]).toStrictEqual({ + type: 'create_table', + tableName: 'users', + schema: '', + columns: [ + { + name: 'id', + type: 'serial', + notNull: true, + primaryKey: true, + }, + { + name: 'age', + type: 'integer', + notNull: false, + primaryKey: false, + }, + ], + compositePKs: [], + checkConstraints: [';"users"."age" > 21'], + compositePkName: '', + uniqueConstraints: [], + isRLSEnabled: false, + policies: [], + } as JsonCreateTableStatement); + + expect(sqlStatements.length).toBe(1); + expect(sqlStatements[0]).toBe(`CREATE TABLE "users" ( +\t"id" serial PRIMARY KEY NOT NULL, +\t"age" integer, +\tCHECK ("users"."age" > 21) +);\n`); +}); diff --git a/drizzle-kit/tests/pg-domains.test.ts b/drizzle-kit/tests/pg-domains.test.ts index 5d8172580..08c23e88b 100644 --- a/drizzle-kit/tests/pg-domains.test.ts +++ b/drizzle-kit/tests/pg-domains.test.ts @@ -284,3 +284,30 @@ test('domains #11 alter domain to drop default value', async () => { defaultValue: undefined, }); }); + +test('domains #12 create domain with unnamed constraint', async () => { + const to = { + domain: pgDomain('domain', 'text', { + checkConstraints: [check(sql`VALUE ~ '^[A-Za-z]+$'`)], + }), + }; + + const { statements, sqlStatements } = await diffTestSchemas({}, to, []); + + expect(sqlStatements.length).toBe(1); + expect(sqlStatements[0]).toBe( + `CREATE DOMAIN "public"."domain" AS text CHECK (VALUE ~ '^[A-Za-z]+$');`, + ); + expect(statements.length).toBe(1); + expect(statements[0]).toStrictEqual({ + type: 'create_domain', + name: 'domain', + schema: 'public', + baseType: 'text', + notNull: false, + defaultValue: undefined, + checkConstraints: [ + ";VALUE ~ '^[A-Za-z]+$'", + ], + }); +}); diff --git a/drizzle-orm/src/pg-core/checks.ts b/drizzle-orm/src/pg-core/checks.ts index 43740427b..cc6061d2c 100644 --- a/drizzle-orm/src/pg-core/checks.ts +++ b/drizzle-orm/src/pg-core/checks.ts @@ -7,7 +7,7 @@ export class CheckBuilder { protected brand!: 'PgConstraintBuilder'; - constructor(public name: string, public value: SQL) {} + constructor(public name: string | undefined, public value: SQL) {} /** @internal */ build(table: PgTable): Check { @@ -18,7 +18,7 @@ export class CheckBuilder { export class Check { static readonly [entityKind]: string = 'PgCheck'; - readonly name: string; + readonly name?: string; readonly value: SQL; constructor(public table: PgTable, builder: CheckBuilder) { @@ -27,6 +27,12 @@ export class Check { } } -export function check(name: string, value: SQL): CheckBuilder { - return new CheckBuilder(name, value); +export function check(value: SQL): CheckBuilder; +export function check(name: string, value: SQL): CheckBuilder; +export function check(nameOrValue: string | SQL, maybeValue?: SQL): CheckBuilder { + if (maybeValue === undefined) { + // Only one argument: treat it as the SQL value. + return new CheckBuilder(undefined, nameOrValue as SQL); + } + return new CheckBuilder(nameOrValue as string, maybeValue); } diff --git a/drizzle-zod/tests/pg-checks.test.ts b/drizzle-zod/tests/pg-checks.test.ts index 54478b7a5..1495d21a3 100644 --- a/drizzle-zod/tests/pg-checks.test.ts +++ b/drizzle-zod/tests/pg-checks.test.ts @@ -1,10 +1,10 @@ -import {type Equal, sql} from 'drizzle-orm'; -import {check, integer, pgDomain, pgTable, serial, text,} from 'drizzle-orm/pg-core'; -import {test} from 'vitest'; -import {z} from 'zod'; -import {CONSTANTS} from '~/constants.ts'; -import {createSelectSchema} from '../src'; -import {Expect, expectSchemaShape} from './utils.ts'; +import { type Equal, sql } from 'drizzle-orm'; +import { check, integer, pgDomain, pgTable, serial, text } from 'drizzle-orm/pg-core'; +import { test } from 'vitest'; +import { z } from 'zod'; +import { CONSTANTS } from '~/constants.ts'; +import { createSelectSchema } from '../src'; +import { Expect, expectSchemaShape } from './utils.ts'; // TODO think about what to do with the existing filters being added when check constraints are involved const integerSchema = z.number().min(CONSTANTS.INT32_MIN).max(CONSTANTS.INT32_MAX).int();