diff --git a/src/main/kotlin/org/jetbrains/exposed/dao/Entity.kt b/src/main/kotlin/org/jetbrains/exposed/dao/Entity.kt index 13fb7b15b1..8c5df410c0 100644 --- a/src/main/kotlin/org/jetbrains/exposed/dao/Entity.kt +++ b/src/main/kotlin/org/jetbrains/exposed/dao/Entity.kt @@ -380,7 +380,7 @@ class EntityCache { val insertedTables = inserts.keys - for (t in sortTablesByReferences(tables)) { + for (t in SchemaUtils.sortTablesByReferences(tables)) { flushInserts(t as IdTable<*>) } @@ -466,36 +466,6 @@ class EntityCache { it.expireCache() } } - - fun sortTablesByReferences(tables: Iterable) = addDependencies(tables).toCollection(arrayListOf()).run { - if(this.count() <= 1) return this - val canBeReferenced = arrayListOf
() - do { - val (movable, others) = partition { - it.columns.all { it.referee == null || canBeReferenced.contains(it.referee!!.table) || it.referee!!.table == it.table} - } - canBeReferenced.addAll(movable) - this.removeAll(movable) - } while (others.isNotEmpty() && movable.isNotEmpty()) - canBeReferenced.addAll(this) - canBeReferenced - } - - fun addDependencies(tables: Iterable
): Iterable
{ - val workset = HashSet
() - - fun checkTable(table: Table) { - if (workset.add(table)) { - for (c in table.columns) { - c.referee?.table?.let { checkTable(it) } - } - } - } - - for (t in tables) checkTable(t) - - return workset - } } } diff --git a/src/main/kotlin/org/jetbrains/exposed/sql/Constraints.kt b/src/main/kotlin/org/jetbrains/exposed/sql/Constraints.kt index 672b187e6c..d075238028 100644 --- a/src/main/kotlin/org/jetbrains/exposed/sql/Constraints.kt +++ b/src/main/kotlin/org/jetbrains/exposed/sql/Constraints.kt @@ -38,9 +38,11 @@ data class ForeignKeyConstraint(val fkName: String, val refereeTable: String, va companion object { fun from(column: Column<*>): ForeignKeyConstraint { assert(column.referee != null && (column.onDelete != null || column.onUpdate != null)) { "$column does not reference anything" } - val s = TransactionManager.current() - return ForeignKeyConstraint("", s.identity(column.referee!!.table), - s.identity(column.referee!!), s.identity(column.table), s.identity(column), + val referee = column.referee!! + val t = TransactionManager.current() + val refName = t.quoteIfNecessary(t.cutIfNecessary("fk_${referee.table.tableName}_${referee.name}_${column.name}")) + return ForeignKeyConstraint(refName, t.identity(referee.table), t.identity(referee), + t.identity(column.table), t.identity(column), column.onUpdate ?: ReferenceOption.NO_ACTION, column.onDelete ?: ReferenceOption.NO_ACTION) } @@ -56,7 +58,7 @@ data class ForeignKeyConstraint(val fkName: String, val refereeTable: String, va } } - override fun createStatement() = listOf("ALTER TABLE $referencedTable ADD" + if (fkName.isNotBlank()) " CONSTRAINT $fkName" else "" + foreignKeyPart) + override fun createStatement() = listOf("ALTER TABLE $referencedTable ADD" + (if (fkName.isNotBlank()) " CONSTRAINT $fkName" else "") + foreignKeyPart) override fun dropStatement() = listOf("ALTER TABLE $refereeTable DROP " + when (currentDialect) { diff --git a/src/main/kotlin/org/jetbrains/exposed/sql/SchemaUtils.kt b/src/main/kotlin/org/jetbrains/exposed/sql/SchemaUtils.kt index e190ca58af..d9b3478614 100644 --- a/src/main/kotlin/org/jetbrains/exposed/sql/SchemaUtils.kt +++ b/src/main/kotlin/org/jetbrains/exposed/sql/SchemaUtils.kt @@ -1,33 +1,89 @@ package org.jetbrains.exposed.sql -import org.jetbrains.exposed.dao.EntityCache import org.jetbrains.exposed.sql.transactions.TransactionManager import org.jetbrains.exposed.sql.vendors.currentDialect import org.jetbrains.exposed.sql.vendors.inProperCase import java.util.* object SchemaUtils { - fun createStatements(vararg tables: Table): List { - val statements = ArrayList() - if (tables.isEmpty()) - return statements + private class TableDepthGraph(val tables: List
) { + val graph = fetchAllTables().associate { t -> + t to t.columns.mapNotNull { c -> + c.referee?.let{ it.table to c.columnType.nullable } + }.toMap() + } + + private fun fetchAllTables(): HashSet
{ + val result = HashSet
() + + fun parseTable(table: Table) { + if (result.add(table)) { + table.columns.forEach { + it.referee?.table?.let(::parseTable) + } + } + } + tables.forEach(::parseTable) + return result + } + + fun sorted() : List
{ + val visited = mutableSetOf
() + val result = arrayListOf
() - val newTables = ArrayList
() + fun traverse(table: Table) { + if (table !in visited) { + visited += table + graph[table]!!.forEach { t, u -> + if (t !in visited) { + traverse(t) + } + } + result += table + } + } - for (table in EntityCache.sortTablesByReferences(tables.toList())) { + tables.forEach(::traverse) + return result + } - if (table.exists()) continue else newTables.add(table) + fun hasCycle() : Boolean { + val visited = mutableSetOf
() + val recursion = mutableSetOf
() - // create table - statements.addAll(table.ddl) + val sortedTables = sorted() - // create indices - for (index in table.indices) { - statements.addAll(createIndex(index)) + fun traverse(table: Table) : Boolean { + if (table in recursion) return true + if (table in visited) return false + recursion += table + visited += table + return if (graph[table]!!.any{ traverse(it.key) }) { + true + } else { + recursion -= table + false + } } + return sortedTables.any { traverse(it) } } + } + + fun sortTablesByReferences(tables: Iterable
) = TableDepthGraph(tables.toList()).sorted() + fun checkCycle(vararg tables: Table) = TableDepthGraph(tables.toList()).hasCycle() + + fun createStatements(vararg tables: Table): List { + if (tables.isEmpty()) + return emptyList() - return statements + val toCreate = sortTablesByReferences(tables.toList()).filterNot { it.exists() } + val alters = arrayListOf() + return toCreate.flatMap { table -> + val (create, alter) = table.ddl.partition { it.startsWith("CREATE ") } + val indicesDDL = table.indices.flatMap { createIndex(it) } + alters += alter + create + indicesDDL + } + alters } fun createSequence(name: String) = Seq(name).createStatement() @@ -185,7 +241,7 @@ object SchemaUtils { if (tables.isEmpty()) return val transaction = TransactionManager.current() transaction.flushCache() - var tablesForDeletion = EntityCache + var tablesForDeletion = SchemaUtils .sortTablesByReferences(tables.toList()) .reversed() .filter { it in tables } diff --git a/src/main/kotlin/org/jetbrains/exposed/sql/Table.kt b/src/main/kotlin/org/jetbrains/exposed/sql/Table.kt index a9bba7ce43..f482e740d5 100644 --- a/src/main/kotlin/org/jetbrains/exposed/sql/Table.kt +++ b/src/main/kotlin/org/jetbrains/exposed/sql/Table.kt @@ -3,10 +3,7 @@ package org.jetbrains.exposed.sql import org.jetbrains.exposed.dao.EntityID import org.jetbrains.exposed.dao.IdTable import org.jetbrains.exposed.sql.transactions.TransactionManager -import org.jetbrains.exposed.sql.vendors.OracleDialect -import org.jetbrains.exposed.sql.vendors.currentDialect -import org.jetbrains.exposed.sql.vendors.currentDialectIfAvailable -import org.jetbrains.exposed.sql.vendors.inProperCase +import org.jetbrains.exposed.sql.vendors.* import org.joda.time.DateTime import java.math.BigDecimal import java.sql.Blob @@ -485,6 +482,8 @@ open class Table(name: String = ""): ColumnSet(), DdlAware { Seq(it).createStatement() }.orEmpty() + val addForeignKeysInAlterPart = SchemaUtils.checkCycle(this) && currentDialect !is SQLiteDialect + val createTableDDL = buildString { append("CREATE TABLE ") if (currentDialect.supportsIfNotExists) { @@ -498,24 +497,31 @@ open class Table(name: String = ""): ColumnSet(), DdlAware { append(", $it") } } - columns.filter { it.referee != null }.let { references -> - if (references.isNotEmpty()) { - append(references.joinToString(prefix = ", ", separator = ", ") { ForeignKeyConstraint.from(it).foreignKeyPart }) + + if (!addForeignKeysInAlterPart) { + columns.filter { it.referee != null }.takeIf { it.isNotEmpty() }?.let { references -> + references.joinTo(this, prefix = ", ", separator = ", ") { ForeignKeyConstraint.from(it).foreignKeyPart } } } + if (checkConstraints.isNotEmpty()) { - append( - checkConstraints.mapIndexed { index, (name, op) -> - val resolvedName = name.takeIf { it.isNotBlank() } ?: "check_${tableName}_$index" - CheckConstraint.from(this@Table, resolvedName, op).checkPart - }.joinToString(prefix = ",", separator = ",") - ) + checkConstraints.asSequence().mapIndexed { index, (name, op) -> + val resolvedName = name.takeIf { it.isNotBlank() } ?: "check_${tableName}_$index" + CheckConstraint.from(this@Table, resolvedName, op).checkPart + }.joinTo(this, prefix = ",", separator = ",") } append(")") } } - return seqDDL + createTableDDL + + val constraintSQL = if (addForeignKeysInAlterPart) { + columns.filter { it.referee != null }.flatMap { + ForeignKeyConstraint.from(it).createStatement() + } + } else emptyList() + + return seqDDL + createTableDDL + constraintSQL } internal fun primaryKeyConstraint(): String? { @@ -538,12 +544,16 @@ open class Table(name: String = ""): ColumnSet(), DdlAware { val dropTableDDL = buildString { append("DROP TABLE ") if (currentDialect.supportsIfNotExists) { - append(" IF EXISTS ") + append("IF EXISTS ") } append(TransactionManager.current().identity(this@Table)) if (currentDialectIfAvailable is OracleDialect) { append(" CASCADE CONSTRAINTS") } + + if (currentDialectIfAvailable is PostgreSQLDialect && SchemaUtils.checkCycle(this@Table)) { + append(" CASCADE") + } } val seqDDL = autoIncColumn?.autoIncSeqName?.let { Seq(it).dropStatement() diff --git a/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/DDLTests.kt b/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/DDLTests.kt index 8e8ebbc845..6f9963d0dd 100644 --- a/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/DDLTests.kt +++ b/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/DDLTests.kt @@ -433,7 +433,7 @@ class DDLTests : DatabaseTestsBase() { } } - withTables(excludeSettings = listOf(TestDB.H2, TestDB.H2_MYSQL, TestDB.SQLITE), tables = initialTable) { + withTables(excludeSettings = listOf(TestDB.H2, TestDB.H2_MYSQL, TestDB.SQLITE), tables = *arrayOf(initialTable)) { assertEquals("ALTER TABLE ${tableName.inProperCase()} ADD ${"id".inProperCase()} ${t.id.columnType.sqlType()} PRIMARY KEY", t.id.ddl) assertEquals(1, currentDialect.tableColumns(t)[t]!!.size) SchemaUtils.createMissingTablesAndColumns(t) @@ -489,6 +489,36 @@ class DDLTests : DatabaseTestsBase() { } } + object Table1 : IntIdTable() { + val table2 = reference("teamId", Table2, onDelete = ReferenceOption.CASCADE) + } + + object Table2 : IntIdTable() { + val table1 = optReference("teamId", Table1, onDelete = ReferenceOption.SET_NULL) + } + + @Test fun testCrossReference() { + withTables(Table2, Table1) { + val table2id = Table2.insertAndGetId{} + val table1id = Table1.insertAndGetId { + it[Table1.table2] = table2id + } + + Table2.insertAndGetId { + it[Table2.table1] = table1id + } + + assertEquals(1, Table1.selectAll().count()) + assertEquals(2, Table2.selectAll().count()) + if (currentDialect is MysqlDialect) { + exec("SET foreign_key_checks = 0;") + } + if (currentDialect is PostgreSQLDialect) { + exec("set constraints all deferred;") + } + } + } + @Test fun testUUIDColumnType() { val Node = object: Table("node") { val uuid = uuid("uuid").primaryKey() diff --git a/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/EntityTests.kt b/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/EntityTests.kt index 2099fb1d77..1821a857c3 100644 --- a/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/EntityTests.kt +++ b/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/EntityTests.kt @@ -314,7 +314,7 @@ class EntityTests: DatabaseTestsBase() { @Test fun tableSelfReferenceTest() { - assertEquals>(listOf(Categories, Boards, Posts), EntityCache.sortTablesByReferences(listOf(Posts, Boards, Categories))) + assertEquals(listOf(Categories, Boards, Posts), SchemaUtils.sortTablesByReferences(listOf(Posts, Boards, Categories))) } @Test diff --git a/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/SelfReferenceTest.kt b/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/SelfReferenceTest.kt index b93e6c86bb..ccfdd573f7 100644 --- a/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/SelfReferenceTest.kt +++ b/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/SelfReferenceTest.kt @@ -1,74 +1,81 @@ package org.jetbrains.exposed.sql.tests.shared -import org.jetbrains.exposed.dao.* +import org.jetbrains.exposed.dao.IntIdTable +import org.jetbrains.exposed.sql.SchemaUtils import org.jetbrains.exposed.sql.Table import org.junit.Test -import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertTrue @Suppress("unused") class SortByReferenceTest { @Test fun simpleTest() { - assertEqualLists(listOf(DMLTestsData.Cities), EntityCache.sortTablesByReferences(listOf(DMLTestsData.Cities))) - assertEqualLists(listOf( DMLTestsData.Cities, DMLTestsData.Users), EntityCache.sortTablesByReferences(listOf(DMLTestsData.Users))) + assertEqualLists(listOf(DMLTestsData.Cities), SchemaUtils.sortTablesByReferences(listOf(DMLTestsData.Cities))) + assertEqualLists(listOf(DMLTestsData.Cities, DMLTestsData.Users), SchemaUtils.sortTablesByReferences(listOf(DMLTestsData.Users))) val rightOrder = listOf(DMLTestsData.Cities, DMLTestsData.Users, DMLTestsData.UserData) - val r1 = EntityCache.sortTablesByReferences(listOf(DMLTestsData.Cities, DMLTestsData.UserData, DMLTestsData.Users)) - val r2 = EntityCache.sortTablesByReferences(listOf(DMLTestsData.UserData, DMLTestsData.Cities, DMLTestsData.Users)) - val r3 = EntityCache.sortTablesByReferences(listOf(DMLTestsData.Users, DMLTestsData.Cities, DMLTestsData.UserData)) - assertEqualLists(r1, rightOrder) - assertEqualLists(r2, rightOrder) - assertEqualLists(r3, rightOrder) + val r1 = SchemaUtils.sortTablesByReferences(listOf(DMLTestsData.Cities, DMLTestsData.UserData, DMLTestsData.Users)) + val r2 = SchemaUtils.sortTablesByReferences(listOf(DMLTestsData.UserData, DMLTestsData.Cities, DMLTestsData.Users)) + val r3 = SchemaUtils.sortTablesByReferences(listOf(DMLTestsData.Users, DMLTestsData.Cities, DMLTestsData.UserData)) + assertEqualLists(rightOrder, r1) + assertEqualLists(rightOrder, r2) + assertEqualLists(rightOrder, r3) } - @Test - fun cycleReferencesCheckTest() { - val cities = object : Table() { + object TestTables { + object cities : Table() { val id = integer("id").autoIncrement().primaryKey() val name = varchar("name", 50) - val strange_id = varchar("strange_id", 10) + val strange_id = varchar("strange_id", 10).references(strangeTable.id) } - val users = object : Table() { + object users : Table() { val id = varchar("id", 10).primaryKey() val name = varchar("name", length = 50) val cityId = (integer("city_id") references cities.id).nullable() } - val noRefereeTable = object: Table() { + object noRefereeTable : Table() { val id = varchar("id", 10).primaryKey() val col1 = varchar("col1", 10) } - val refereeTable = object: Table() { + object refereeTable : Table() { val id = varchar("id", 10).primaryKey() val ref = reference("ref", noRefereeTable.id) } - val referencedTable = object: IntIdTable() { + object referencedTable : IntIdTable() { val col3 = varchar("col3", 10) } - val strangeTable = object : Table() { + object strangeTable : Table() { val id = varchar("id", 10).primaryKey() val user_id = varchar("user_id", 10) references users.id val comment = varchar("comment", 30) val value = integer("value") } + } - with (cities) { - strange_id.references( strangeTable.id) - } + @Test + fun cycleReferencesCheckTest() { + val original = listOf(TestTables.cities, TestTables.users, TestTables.strangeTable, TestTables.noRefereeTable, TestTables.refereeTable, TestTables.referencedTable) + val sortedTables = SchemaUtils.sortTablesByReferences(original) + val expected = listOf(TestTables.users, TestTables.strangeTable, TestTables.cities, TestTables.noRefereeTable, TestTables.refereeTable, TestTables.referencedTable) - val sortedTables = EntityCache.sortTablesByReferences(listOf(cities, users, strangeTable, noRefereeTable, refereeTable, referencedTable)) + assertEqualLists(expected, sortedTables) + } - assert(sortedTables.indexOf(referencedTable) in listOf(0,1)) - assert(sortedTables.indexOf(noRefereeTable) in listOf(0,1)) - assertEquals(2, sortedTables.indexOf(refereeTable)) + @Test + fun testHasCycle() { + assertFalse(SchemaUtils.checkCycle(TestTables.referencedTable)) + assertFalse(SchemaUtils.checkCycle(TestTables.refereeTable)) + assertFalse(SchemaUtils.checkCycle(TestTables.noRefereeTable)) + assertTrue(SchemaUtils.checkCycle(TestTables.users)) + assertTrue(SchemaUtils.checkCycle(TestTables.cities)) + assertTrue(SchemaUtils.checkCycle(TestTables.strangeTable)) - assert(sortedTables.indexOf(cities) in listOf(3,4,5)) - assert(sortedTables.indexOf(users) in listOf(3,4,5)) - assert(sortedTables.indexOf(strangeTable) in listOf(3,4,5)) } } \ No newline at end of file