diff --git a/src/main/kotlin/org/jetbrains/exposed/dao/Entity.kt b/src/main/kotlin/org/jetbrains/exposed/dao/Entity.kt
index 13fb7b15b1..8c5df410c0 100644
--- a/src/main/kotlin/org/jetbrains/exposed/dao/Entity.kt
+++ b/src/main/kotlin/org/jetbrains/exposed/dao/Entity.kt
@@ -380,7 +380,7 @@ class EntityCache {
val insertedTables = inserts.keys
- for (t in sortTablesByReferences(tables)) {
+ for (t in SchemaUtils.sortTablesByReferences(tables)) {
flushInserts(t as IdTable<*>)
}
@@ -466,36 +466,6 @@ class EntityCache {
it.expireCache()
}
}
-
- fun sortTablesByReferences(tables: Iterable
) = addDependencies(tables).toCollection(arrayListOf()).run {
- if(this.count() <= 1) return this
- val canBeReferenced = arrayListOf()
- do {
- val (movable, others) = partition {
- it.columns.all { it.referee == null || canBeReferenced.contains(it.referee!!.table) || it.referee!!.table == it.table}
- }
- canBeReferenced.addAll(movable)
- this.removeAll(movable)
- } while (others.isNotEmpty() && movable.isNotEmpty())
- canBeReferenced.addAll(this)
- canBeReferenced
- }
-
- fun addDependencies(tables: Iterable): Iterable {
- val workset = HashSet()
-
- fun checkTable(table: Table) {
- if (workset.add(table)) {
- for (c in table.columns) {
- c.referee?.table?.let { checkTable(it) }
- }
- }
- }
-
- for (t in tables) checkTable(t)
-
- return workset
- }
}
}
diff --git a/src/main/kotlin/org/jetbrains/exposed/sql/Constraints.kt b/src/main/kotlin/org/jetbrains/exposed/sql/Constraints.kt
index 672b187e6c..d075238028 100644
--- a/src/main/kotlin/org/jetbrains/exposed/sql/Constraints.kt
+++ b/src/main/kotlin/org/jetbrains/exposed/sql/Constraints.kt
@@ -38,9 +38,11 @@ data class ForeignKeyConstraint(val fkName: String, val refereeTable: String, va
companion object {
fun from(column: Column<*>): ForeignKeyConstraint {
assert(column.referee != null && (column.onDelete != null || column.onUpdate != null)) { "$column does not reference anything" }
- val s = TransactionManager.current()
- return ForeignKeyConstraint("", s.identity(column.referee!!.table),
- s.identity(column.referee!!), s.identity(column.table), s.identity(column),
+ val referee = column.referee!!
+ val t = TransactionManager.current()
+ val refName = t.quoteIfNecessary(t.cutIfNecessary("fk_${referee.table.tableName}_${referee.name}_${column.name}"))
+ return ForeignKeyConstraint(refName, t.identity(referee.table), t.identity(referee),
+ t.identity(column.table), t.identity(column),
column.onUpdate ?: ReferenceOption.NO_ACTION,
column.onDelete ?: ReferenceOption.NO_ACTION)
}
@@ -56,7 +58,7 @@ data class ForeignKeyConstraint(val fkName: String, val refereeTable: String, va
}
}
- override fun createStatement() = listOf("ALTER TABLE $referencedTable ADD" + if (fkName.isNotBlank()) " CONSTRAINT $fkName" else "" + foreignKeyPart)
+ override fun createStatement() = listOf("ALTER TABLE $referencedTable ADD" + (if (fkName.isNotBlank()) " CONSTRAINT $fkName" else "") + foreignKeyPart)
override fun dropStatement() = listOf("ALTER TABLE $refereeTable DROP " +
when (currentDialect) {
diff --git a/src/main/kotlin/org/jetbrains/exposed/sql/SchemaUtils.kt b/src/main/kotlin/org/jetbrains/exposed/sql/SchemaUtils.kt
index e190ca58af..d9b3478614 100644
--- a/src/main/kotlin/org/jetbrains/exposed/sql/SchemaUtils.kt
+++ b/src/main/kotlin/org/jetbrains/exposed/sql/SchemaUtils.kt
@@ -1,33 +1,89 @@
package org.jetbrains.exposed.sql
-import org.jetbrains.exposed.dao.EntityCache
import org.jetbrains.exposed.sql.transactions.TransactionManager
import org.jetbrains.exposed.sql.vendors.currentDialect
import org.jetbrains.exposed.sql.vendors.inProperCase
import java.util.*
object SchemaUtils {
- fun createStatements(vararg tables: Table): List {
- val statements = ArrayList()
- if (tables.isEmpty())
- return statements
+ private class TableDepthGraph(val tables: List) {
+ val graph = fetchAllTables().associate { t ->
+ t to t.columns.mapNotNull { c ->
+ c.referee?.let{ it.table to c.columnType.nullable }
+ }.toMap()
+ }
+
+ private fun fetchAllTables(): HashSet {
+ val result = HashSet()
+
+ fun parseTable(table: Table) {
+ if (result.add(table)) {
+ table.columns.forEach {
+ it.referee?.table?.let(::parseTable)
+ }
+ }
+ }
+ tables.forEach(::parseTable)
+ return result
+ }
+
+ fun sorted() : List {
+ val visited = mutableSetOf()
+ val result = arrayListOf()
- val newTables = ArrayList()
+ fun traverse(table: Table) {
+ if (table !in visited) {
+ visited += table
+ graph[table]!!.forEach { t, u ->
+ if (t !in visited) {
+ traverse(t)
+ }
+ }
+ result += table
+ }
+ }
- for (table in EntityCache.sortTablesByReferences(tables.toList())) {
+ tables.forEach(::traverse)
+ return result
+ }
- if (table.exists()) continue else newTables.add(table)
+ fun hasCycle() : Boolean {
+ val visited = mutableSetOf()
+ val recursion = mutableSetOf()
- // create table
- statements.addAll(table.ddl)
+ val sortedTables = sorted()
- // create indices
- for (index in table.indices) {
- statements.addAll(createIndex(index))
+ fun traverse(table: Table) : Boolean {
+ if (table in recursion) return true
+ if (table in visited) return false
+ recursion += table
+ visited += table
+ return if (graph[table]!!.any{ traverse(it.key) }) {
+ true
+ } else {
+ recursion -= table
+ false
+ }
}
+ return sortedTables.any { traverse(it) }
}
+ }
+
+ fun sortTablesByReferences(tables: Iterable) = TableDepthGraph(tables.toList()).sorted()
+ fun checkCycle(vararg tables: Table) = TableDepthGraph(tables.toList()).hasCycle()
+
+ fun createStatements(vararg tables: Table): List {
+ if (tables.isEmpty())
+ return emptyList()
- return statements
+ val toCreate = sortTablesByReferences(tables.toList()).filterNot { it.exists() }
+ val alters = arrayListOf()
+ return toCreate.flatMap { table ->
+ val (create, alter) = table.ddl.partition { it.startsWith("CREATE ") }
+ val indicesDDL = table.indices.flatMap { createIndex(it) }
+ alters += alter
+ create + indicesDDL
+ } + alters
}
fun createSequence(name: String) = Seq(name).createStatement()
@@ -185,7 +241,7 @@ object SchemaUtils {
if (tables.isEmpty()) return
val transaction = TransactionManager.current()
transaction.flushCache()
- var tablesForDeletion = EntityCache
+ var tablesForDeletion = SchemaUtils
.sortTablesByReferences(tables.toList())
.reversed()
.filter { it in tables }
diff --git a/src/main/kotlin/org/jetbrains/exposed/sql/Table.kt b/src/main/kotlin/org/jetbrains/exposed/sql/Table.kt
index a9bba7ce43..f482e740d5 100644
--- a/src/main/kotlin/org/jetbrains/exposed/sql/Table.kt
+++ b/src/main/kotlin/org/jetbrains/exposed/sql/Table.kt
@@ -3,10 +3,7 @@ package org.jetbrains.exposed.sql
import org.jetbrains.exposed.dao.EntityID
import org.jetbrains.exposed.dao.IdTable
import org.jetbrains.exposed.sql.transactions.TransactionManager
-import org.jetbrains.exposed.sql.vendors.OracleDialect
-import org.jetbrains.exposed.sql.vendors.currentDialect
-import org.jetbrains.exposed.sql.vendors.currentDialectIfAvailable
-import org.jetbrains.exposed.sql.vendors.inProperCase
+import org.jetbrains.exposed.sql.vendors.*
import org.joda.time.DateTime
import java.math.BigDecimal
import java.sql.Blob
@@ -485,6 +482,8 @@ open class Table(name: String = ""): ColumnSet(), DdlAware {
Seq(it).createStatement()
}.orEmpty()
+ val addForeignKeysInAlterPart = SchemaUtils.checkCycle(this) && currentDialect !is SQLiteDialect
+
val createTableDDL = buildString {
append("CREATE TABLE ")
if (currentDialect.supportsIfNotExists) {
@@ -498,24 +497,31 @@ open class Table(name: String = ""): ColumnSet(), DdlAware {
append(", $it")
}
}
- columns.filter { it.referee != null }.let { references ->
- if (references.isNotEmpty()) {
- append(references.joinToString(prefix = ", ", separator = ", ") { ForeignKeyConstraint.from(it).foreignKeyPart })
+
+ if (!addForeignKeysInAlterPart) {
+ columns.filter { it.referee != null }.takeIf { it.isNotEmpty() }?.let { references ->
+ references.joinTo(this, prefix = ", ", separator = ", ") { ForeignKeyConstraint.from(it).foreignKeyPart }
}
}
+
if (checkConstraints.isNotEmpty()) {
- append(
- checkConstraints.mapIndexed { index, (name, op) ->
- val resolvedName = name.takeIf { it.isNotBlank() } ?: "check_${tableName}_$index"
- CheckConstraint.from(this@Table, resolvedName, op).checkPart
- }.joinToString(prefix = ",", separator = ",")
- )
+ checkConstraints.asSequence().mapIndexed { index, (name, op) ->
+ val resolvedName = name.takeIf { it.isNotBlank() } ?: "check_${tableName}_$index"
+ CheckConstraint.from(this@Table, resolvedName, op).checkPart
+ }.joinTo(this, prefix = ",", separator = ",")
}
append(")")
}
}
- return seqDDL + createTableDDL
+
+ val constraintSQL = if (addForeignKeysInAlterPart) {
+ columns.filter { it.referee != null }.flatMap {
+ ForeignKeyConstraint.from(it).createStatement()
+ }
+ } else emptyList()
+
+ return seqDDL + createTableDDL + constraintSQL
}
internal fun primaryKeyConstraint(): String? {
@@ -538,12 +544,16 @@ open class Table(name: String = ""): ColumnSet(), DdlAware {
val dropTableDDL = buildString {
append("DROP TABLE ")
if (currentDialect.supportsIfNotExists) {
- append(" IF EXISTS ")
+ append("IF EXISTS ")
}
append(TransactionManager.current().identity(this@Table))
if (currentDialectIfAvailable is OracleDialect) {
append(" CASCADE CONSTRAINTS")
}
+
+ if (currentDialectIfAvailable is PostgreSQLDialect && SchemaUtils.checkCycle(this@Table)) {
+ append(" CASCADE")
+ }
}
val seqDDL = autoIncColumn?.autoIncSeqName?.let {
Seq(it).dropStatement()
diff --git a/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/DDLTests.kt b/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/DDLTests.kt
index 8e8ebbc845..6f9963d0dd 100644
--- a/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/DDLTests.kt
+++ b/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/DDLTests.kt
@@ -433,7 +433,7 @@ class DDLTests : DatabaseTestsBase() {
}
}
- withTables(excludeSettings = listOf(TestDB.H2, TestDB.H2_MYSQL, TestDB.SQLITE), tables = initialTable) {
+ withTables(excludeSettings = listOf(TestDB.H2, TestDB.H2_MYSQL, TestDB.SQLITE), tables = *arrayOf(initialTable)) {
assertEquals("ALTER TABLE ${tableName.inProperCase()} ADD ${"id".inProperCase()} ${t.id.columnType.sqlType()} PRIMARY KEY", t.id.ddl)
assertEquals(1, currentDialect.tableColumns(t)[t]!!.size)
SchemaUtils.createMissingTablesAndColumns(t)
@@ -489,6 +489,36 @@ class DDLTests : DatabaseTestsBase() {
}
}
+ object Table1 : IntIdTable() {
+ val table2 = reference("teamId", Table2, onDelete = ReferenceOption.CASCADE)
+ }
+
+ object Table2 : IntIdTable() {
+ val table1 = optReference("teamId", Table1, onDelete = ReferenceOption.SET_NULL)
+ }
+
+ @Test fun testCrossReference() {
+ withTables(Table2, Table1) {
+ val table2id = Table2.insertAndGetId{}
+ val table1id = Table1.insertAndGetId {
+ it[Table1.table2] = table2id
+ }
+
+ Table2.insertAndGetId {
+ it[Table2.table1] = table1id
+ }
+
+ assertEquals(1, Table1.selectAll().count())
+ assertEquals(2, Table2.selectAll().count())
+ if (currentDialect is MysqlDialect) {
+ exec("SET foreign_key_checks = 0;")
+ }
+ if (currentDialect is PostgreSQLDialect) {
+ exec("set constraints all deferred;")
+ }
+ }
+ }
+
@Test fun testUUIDColumnType() {
val Node = object: Table("node") {
val uuid = uuid("uuid").primaryKey()
diff --git a/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/EntityTests.kt b/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/EntityTests.kt
index 2099fb1d77..1821a857c3 100644
--- a/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/EntityTests.kt
+++ b/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/EntityTests.kt
@@ -314,7 +314,7 @@ class EntityTests: DatabaseTestsBase() {
@Test
fun tableSelfReferenceTest() {
- assertEquals>(listOf(Categories, Boards, Posts), EntityCache.sortTablesByReferences(listOf(Posts, Boards, Categories)))
+ assertEquals(listOf(Categories, Boards, Posts), SchemaUtils.sortTablesByReferences(listOf(Posts, Boards, Categories)))
}
@Test
diff --git a/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/SelfReferenceTest.kt b/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/SelfReferenceTest.kt
index b93e6c86bb..ccfdd573f7 100644
--- a/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/SelfReferenceTest.kt
+++ b/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/SelfReferenceTest.kt
@@ -1,74 +1,81 @@
package org.jetbrains.exposed.sql.tests.shared
-import org.jetbrains.exposed.dao.*
+import org.jetbrains.exposed.dao.IntIdTable
+import org.jetbrains.exposed.sql.SchemaUtils
import org.jetbrains.exposed.sql.Table
import org.junit.Test
-import kotlin.test.assertEquals
+import kotlin.test.assertFalse
+import kotlin.test.assertTrue
@Suppress("unused")
class SortByReferenceTest {
@Test
fun simpleTest() {
- assertEqualLists(listOf(DMLTestsData.Cities), EntityCache.sortTablesByReferences(listOf(DMLTestsData.Cities)))
- assertEqualLists(listOf( DMLTestsData.Cities, DMLTestsData.Users), EntityCache.sortTablesByReferences(listOf(DMLTestsData.Users)))
+ assertEqualLists(listOf(DMLTestsData.Cities), SchemaUtils.sortTablesByReferences(listOf(DMLTestsData.Cities)))
+ assertEqualLists(listOf(DMLTestsData.Cities, DMLTestsData.Users), SchemaUtils.sortTablesByReferences(listOf(DMLTestsData.Users)))
val rightOrder = listOf(DMLTestsData.Cities, DMLTestsData.Users, DMLTestsData.UserData)
- val r1 = EntityCache.sortTablesByReferences(listOf(DMLTestsData.Cities, DMLTestsData.UserData, DMLTestsData.Users))
- val r2 = EntityCache.sortTablesByReferences(listOf(DMLTestsData.UserData, DMLTestsData.Cities, DMLTestsData.Users))
- val r3 = EntityCache.sortTablesByReferences(listOf(DMLTestsData.Users, DMLTestsData.Cities, DMLTestsData.UserData))
- assertEqualLists(r1, rightOrder)
- assertEqualLists(r2, rightOrder)
- assertEqualLists(r3, rightOrder)
+ val r1 = SchemaUtils.sortTablesByReferences(listOf(DMLTestsData.Cities, DMLTestsData.UserData, DMLTestsData.Users))
+ val r2 = SchemaUtils.sortTablesByReferences(listOf(DMLTestsData.UserData, DMLTestsData.Cities, DMLTestsData.Users))
+ val r3 = SchemaUtils.sortTablesByReferences(listOf(DMLTestsData.Users, DMLTestsData.Cities, DMLTestsData.UserData))
+ assertEqualLists(rightOrder, r1)
+ assertEqualLists(rightOrder, r2)
+ assertEqualLists(rightOrder, r3)
}
- @Test
- fun cycleReferencesCheckTest() {
- val cities = object : Table() {
+ object TestTables {
+ object cities : Table() {
val id = integer("id").autoIncrement().primaryKey()
val name = varchar("name", 50)
- val strange_id = varchar("strange_id", 10)
+ val strange_id = varchar("strange_id", 10).references(strangeTable.id)
}
- val users = object : Table() {
+ object users : Table() {
val id = varchar("id", 10).primaryKey()
val name = varchar("name", length = 50)
val cityId = (integer("city_id") references cities.id).nullable()
}
- val noRefereeTable = object: Table() {
+ object noRefereeTable : Table() {
val id = varchar("id", 10).primaryKey()
val col1 = varchar("col1", 10)
}
- val refereeTable = object: Table() {
+ object refereeTable : Table() {
val id = varchar("id", 10).primaryKey()
val ref = reference("ref", noRefereeTable.id)
}
- val referencedTable = object: IntIdTable() {
+ object referencedTable : IntIdTable() {
val col3 = varchar("col3", 10)
}
- val strangeTable = object : Table() {
+ object strangeTable : Table() {
val id = varchar("id", 10).primaryKey()
val user_id = varchar("user_id", 10) references users.id
val comment = varchar("comment", 30)
val value = integer("value")
}
+ }
- with (cities) {
- strange_id.references( strangeTable.id)
- }
+ @Test
+ fun cycleReferencesCheckTest() {
+ val original = listOf(TestTables.cities, TestTables.users, TestTables.strangeTable, TestTables.noRefereeTable, TestTables.refereeTable, TestTables.referencedTable)
+ val sortedTables = SchemaUtils.sortTablesByReferences(original)
+ val expected = listOf(TestTables.users, TestTables.strangeTable, TestTables.cities, TestTables.noRefereeTable, TestTables.refereeTable, TestTables.referencedTable)
- val sortedTables = EntityCache.sortTablesByReferences(listOf(cities, users, strangeTable, noRefereeTable, refereeTable, referencedTable))
+ assertEqualLists(expected, sortedTables)
+ }
- assert(sortedTables.indexOf(referencedTable) in listOf(0,1))
- assert(sortedTables.indexOf(noRefereeTable) in listOf(0,1))
- assertEquals(2, sortedTables.indexOf(refereeTable))
+ @Test
+ fun testHasCycle() {
+ assertFalse(SchemaUtils.checkCycle(TestTables.referencedTable))
+ assertFalse(SchemaUtils.checkCycle(TestTables.refereeTable))
+ assertFalse(SchemaUtils.checkCycle(TestTables.noRefereeTable))
+ assertTrue(SchemaUtils.checkCycle(TestTables.users))
+ assertTrue(SchemaUtils.checkCycle(TestTables.cities))
+ assertTrue(SchemaUtils.checkCycle(TestTables.strangeTable))
- assert(sortedTables.indexOf(cities) in listOf(3,4,5))
- assert(sortedTables.indexOf(users) in listOf(3,4,5))
- assert(sortedTables.indexOf(strangeTable) in listOf(3,4,5))
}
}
\ No newline at end of file