From f2670fc770433d40073aec6de0205526d4fffc7f Mon Sep 17 00:00:00 2001 From: bog-walk <82039410+bog-walk@users.noreply.github.com> Date: Thu, 15 Aug 2024 22:08:43 -0400 Subject: [PATCH] feat: EXPOSED-486 Support REPLACE INTO ... SELECT clause (#2199) * feat: EXPOSED-486 Support REPLACE INTO ... SELECT clause Existing insert() has an overload that accepts a query argument to generate SQL like INSERT INTO ... SELECT. replace() also exists for supporting databases, but does not have a similar overload. This implements a new statement class that extends InsertSelectStatement and provides a new extension function overload. --- .../Writerside/topics/Deep-Dive-into-DSL.md | 29 ++++++++- exposed-core/api/exposed-core.api | 7 ++ .../org/jetbrains/exposed/sql/ColumnType.kt | 2 +- .../org/jetbrains/exposed/sql/Queries.kt | 24 ++++++- .../sql/statements/ReplaceStatement.kt | 24 +++++++ .../sql/tests/shared/ddl/SequencesTests.kt | 65 ++++++------------- .../sql/tests/shared/dml/ReplaceTests.kt | 53 +++++++++++++++ 7 files changed, 152 insertions(+), 52 deletions(-) diff --git a/documentation-website/Writerside/topics/Deep-Dive-into-DSL.md b/documentation-website/Writerside/topics/Deep-Dive-into-DSL.md index 32cb5e06b4..c9f30ac03b 100644 --- a/documentation-website/Writerside/topics/Deep-Dive-into-DSL.md +++ b/documentation-website/Writerside/topics/Deep-Dive-into-DSL.md @@ -616,15 +616,15 @@ PostgresSQL [here](https://jdbc.postgresql.org/documentation/logging/). ## Insert From Select -If you want to use `INSERT INTO ... SELECT ` SQL clause try Exposed analog `Table.insert(Query)`. +If you want to use the `INSERT INTO ... SELECT ` SQL clause try the function `Table.insert(Query)`: ```kotlin val substring = users.name.substring(1, 2) cities.insert(users.select(substring).orderBy(users.id).limit(2)) ``` -By default it will try to insert into all non auto-increment `Table` columns in order they defined in Table instance. If you want to specify columns or change the -order, provide list of columns as second parameter: +By default, it will try to insert into all non auto-increment `Table` columns in the order they are defined in the `Table` instance. If you want to specify columns or change the +order, provide a list of columns as the second parameter: ```kotlin val userCount = users.selectAll().count() @@ -767,6 +767,29 @@ In the example above, if the original row was inserted with a user-defined +The `REPLACE INTO ... SELECT ` SQL clause can be used by instead providing a query to `replace()`: + +```kotlin +val allRowsWithLowRating: Query = StarWarsFilms.selectAll().where { + StarWarsFilms.rating less 5.0 +} +StarWarsFilms.replace(allRowsWithLowRating) +``` + +By default, it will try to insert into all non auto-increment `Table` columns in the order they are defined in the `Table` instance. +If the columns need to be specified or the order should be changed, provide a list of columns as the second parameter: + +```kotlin +val oneYearLater = StarWarsFilms.releaseYear.plus(1) +val allRowsWithNewYear: Query = StarWarsFilms.select( + oneYearLater, StarWarsFilms.sequelId, StarWarsFilms.director, StarWarsFilms.name +) +StarWarsFilms.replace( + allRowsWithNewYear, + columns = listOf(StarWarsFilms.releaseYear, StarWarsFilms.sequelId, StarWarsFilms.director, StarWarsFilms.name) +) +``` + ## Column transformation Column transformations allow to define custom transformations between database column types and application's data types. diff --git a/exposed-core/api/exposed-core.api b/exposed-core/api/exposed-core.api index 4d7821983b..fa0938895b 100644 --- a/exposed-core/api/exposed-core.api +++ b/exposed-core/api/exposed-core.api @@ -1817,6 +1817,8 @@ public final class org/jetbrains/exposed/sql/QueriesKt { public static final fun mergeFrom (Lorg/jetbrains/exposed/sql/Table;Lorg/jetbrains/exposed/sql/Table;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;)Lorg/jetbrains/exposed/sql/statements/MergeTableStatement; public static synthetic fun mergeFrom$default (Lorg/jetbrains/exposed/sql/Table;Lorg/jetbrains/exposed/sql/Table;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lorg/jetbrains/exposed/sql/statements/MergeTableStatement; public static final fun replace (Lorg/jetbrains/exposed/sql/Table;Lkotlin/jvm/functions/Function2;)Lorg/jetbrains/exposed/sql/statements/ReplaceStatement; + public static final fun replace (Lorg/jetbrains/exposed/sql/Table;Lorg/jetbrains/exposed/sql/AbstractQuery;Ljava/util/List;)Ljava/lang/Integer; + public static synthetic fun replace$default (Lorg/jetbrains/exposed/sql/Table;Lorg/jetbrains/exposed/sql/AbstractQuery;Ljava/util/List;ILjava/lang/Object;)Ljava/lang/Integer; public static final fun select (Lorg/jetbrains/exposed/sql/FieldSet;Lkotlin/jvm/functions/Function1;)Lorg/jetbrains/exposed/sql/Query; public static final fun select (Lorg/jetbrains/exposed/sql/FieldSet;Lorg/jetbrains/exposed/sql/Op;)Lorg/jetbrains/exposed/sql/Query; public static final fun select (Lorg/jetbrains/exposed/sql/Query;Lkotlin/jvm/functions/Function1;)Lorg/jetbrains/exposed/sql/Query; @@ -3213,6 +3215,11 @@ public class org/jetbrains/exposed/sql/statements/MergeTableStatement : org/jetb public fun prepareSQL (Lorg/jetbrains/exposed/sql/Transaction;Z)Ljava/lang/String; } +public class org/jetbrains/exposed/sql/statements/ReplaceSelectStatement : org/jetbrains/exposed/sql/statements/InsertSelectStatement { + public fun (Ljava/util/List;Lorg/jetbrains/exposed/sql/AbstractQuery;)V + public fun prepareSQL (Lorg/jetbrains/exposed/sql/Transaction;Z)Ljava/lang/String; +} + public class org/jetbrains/exposed/sql/statements/ReplaceStatement : org/jetbrains/exposed/sql/statements/InsertStatement { public fun (Lorg/jetbrains/exposed/sql/Table;)V public fun prepareSQL (Lorg/jetbrains/exposed/sql/Transaction;Z)Ljava/lang/String; diff --git a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/ColumnType.kt b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/ColumnType.kt index b081f68211..7034a53702 100644 --- a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/ColumnType.kt +++ b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/ColumnType.kt @@ -210,7 +210,7 @@ class AutoIncColumnType( val IColumnType<*>.isAutoInc: Boolean get() = this is AutoIncColumnType || (this is EntityIDColumnType<*> && idColumn.columnType.isAutoInc) -/** Returns the name of the auto-increment sequence of this column. */ +/** Returns this column's type cast as [AutoIncColumnType] or `null` if the cast fails. */ val Column<*>.autoIncColumnType: AutoIncColumnType<*>? get() = (columnType as? AutoIncColumnType) ?: (columnType as? EntityIDColumnType<*>)?.idColumn?.columnType as? AutoIncColumnType diff --git a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Queries.kt b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Queries.kt index 57797a7942..be5ba406b8 100644 --- a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Queries.kt +++ b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Queries.kt @@ -357,6 +357,23 @@ fun T.replace(body: T.(UpdateBuilder<*>) -> Unit): ReplaceStatement< execute(TransactionManager.current()) } +/** + * Represents the SQL statement that uses data retrieved from a [selectQuery] to either insert a new row into a table, + * or, if insertion would violate a unique constraint, first delete the existing row before inserting a new row. + * + * **Note:** This operation is not supported by all vendors, please check the documentation. + * + * @param selectQuery Source `SELECT` query that provides the values to insert. + * @param columns Columns to either insert values into or delete values from then insert into. This defaults to all + * columns in the table that are not auto-increment columns without a valid sequence to generate new values. + * @return The number of inserted (and possibly deleted) rows, or `null` if nothing was retrieved after statement execution. + * @sample org.jetbrains.exposed.sql.tests.shared.dml.ReplaceTests.testReplaceSelect + */ +fun T.replace( + selectQuery: AbstractQuery<*>, + columns: List> = this.columns.filter { it.isValidIfAutoIncrement() } +): Int? = ReplaceSelectStatement(columns, selectQuery).execute(TransactionManager.current()) + /** * Represents the SQL statement that uses data retrieved from a [selectQuery] to insert new rows into a table. * @@ -368,7 +385,7 @@ fun T.replace(body: T.(UpdateBuilder<*>) -> Unit): ReplaceStatement< */ fun T.insert( selectQuery: AbstractQuery<*>, - columns: List> = this.columns.filter { !it.columnType.isAutoInc || it.autoIncColumnType?.nextValExpression != null } + columns: List> = this.columns.filter { it.isValidIfAutoIncrement() } ): Int? = InsertSelectStatement(columns, selectQuery).execute(TransactionManager.current()) /** @@ -384,9 +401,12 @@ fun T.insert( */ fun T.insertIgnore( selectQuery: AbstractQuery<*>, - columns: List> = this.columns.filter { !it.columnType.isAutoInc || it.autoIncColumnType?.nextValExpression != null } + columns: List> = this.columns.filter { it.isValidIfAutoIncrement() } ): Int? = InsertSelectStatement(columns, selectQuery, true).execute(TransactionManager.current()) +private fun Column<*>.isValidIfAutoIncrement(): Boolean = + !columnType.isAutoInc || autoIncColumnType?.nextValExpression != null + /** * Represents the SQL statement that inserts new rows into a table and returns specified data from the inserted rows. * diff --git a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/ReplaceStatement.kt b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/ReplaceStatement.kt index 095920107f..88dc6f69dd 100644 --- a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/ReplaceStatement.kt +++ b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/ReplaceStatement.kt @@ -1,5 +1,7 @@ package org.jetbrains.exposed.sql.statements +import org.jetbrains.exposed.sql.AbstractQuery +import org.jetbrains.exposed.sql.Column import org.jetbrains.exposed.sql.Table import org.jetbrains.exposed.sql.Transaction import org.jetbrains.exposed.sql.vendors.H2Dialect @@ -24,3 +26,25 @@ open class ReplaceStatement(table: Table) : InsertStatement(tabl return functionProvider.replace(table, values.unzip().first, valuesSql, transaction, prepared) } } + +/** + * Represents the SQL statement that uses data retrieved from a [selectQuery] to either insert a new row into a table, + * or, if insertion would violate a unique constraint, first delete the existing row before inserting a new row. + * + * @param columns Columns to either insert values into or delete values from then insert into. + * @param selectQuery Source SELECT query that provides the values to insert. + */ +open class ReplaceSelectStatement( + columns: List>, + selectQuery: AbstractQuery<*> +) : InsertSelectStatement(columns, selectQuery) { + override fun prepareSQL(transaction: Transaction, prepared: Boolean): String { + val querySql = selectQuery.prepareSQL(transaction, prepared) + val dialect = transaction.db.dialect + val functionProvider = when (dialect.h2Mode) { + H2Dialect.H2CompatibilityMode.MySQL, H2Dialect.H2CompatibilityMode.MariaDB -> MysqlFunctionProvider() + else -> dialect.functionProvider + } + return functionProvider.replace(targets.single(), columns, querySql, transaction, prepared) + } +} diff --git a/exposed-tests/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/ddl/SequencesTests.kt b/exposed-tests/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/ddl/SequencesTests.kt index 9859ad8f50..28ad4c3db3 100644 --- a/exposed-tests/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/ddl/SequencesTests.kt +++ b/exposed-tests/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/ddl/SequencesTests.kt @@ -63,19 +63,10 @@ class SequencesTests : DatabaseTestsBase() { } @Test - fun `testInsertWithCustomSequence`() { - val customSequence = Sequence( - name = "my_sequence", - startWith = 4, - incrementBy = 2, - minValue = 1, - maxValue = 100, - cycle = true, - cache = 20 - ) + fun testInsertWithCustomSequence() { val tester = object : Table("tester") { - val id = integer("id").autoIncrement(customSequence) - var name = varchar("name", 25) + val id = integer("id").autoIncrement(myseq) + val name = varchar("name", 25) override val primaryKey = PrimaryKey(id, name) } @@ -83,22 +74,22 @@ class SequencesTests : DatabaseTestsBase() { if (currentDialectTest.supportsSequenceAsGeneratedKeys) { try { SchemaUtils.create(tester) - assertTrue(customSequence.exists()) + assertTrue(myseq.exists()) var testerId = tester.insert { it[name] = "Hichem" } get tester.id - assertEquals(customSequence.startWith, testerId.toLong()) + assertEquals(myseq.startWith, testerId.toLong()) testerId = tester.insert { it[name] = "Andrey" } get tester.id - assertEquals(customSequence.startWith!! + customSequence.incrementBy!!, testerId.toLong()) + assertEquals(myseq.startWith!! + myseq.incrementBy!!, testerId.toLong()) } finally { SchemaUtils.drop(tester) - assertFalse(customSequence.exists()) + assertFalse(myseq.exists()) } } } @@ -131,19 +122,10 @@ class SequencesTests : DatabaseTestsBase() { } @Test - fun `testInsertInIdTableWithCustomSequence`() { - val customSequence = Sequence( - name = "my_sequence", - startWith = 4, - incrementBy = 2, - minValue = 1, - maxValue = 100, - cycle = true, - cache = 20 - ) + fun testInsertInIdTableWithCustomSequence() { val tester = object : IdTable("tester") { - override val id = long("id").autoIncrement(customSequence).entityId() - var name = varchar("name", 25) + override val id = long("id").autoIncrement(myseq).entityId() + val name = varchar("name", 25) override val primaryKey = PrimaryKey(id, name) } @@ -151,22 +133,22 @@ class SequencesTests : DatabaseTestsBase() { if (currentDialectTest.supportsSequenceAsGeneratedKeys) { try { SchemaUtils.create(tester) - assertTrue(customSequence.exists()) + assertTrue(myseq.exists()) var testerId = tester.insert { it[name] = "Hichem" } get tester.id - assertEquals(customSequence.startWith, testerId.value) + assertEquals(myseq.startWith, testerId.value) testerId = tester.insert { it[name] = "Andrey" } get tester.id - assertEquals(customSequence.startWith!! + customSequence.incrementBy!!, testerId.value) + assertEquals(myseq.startWith!! + myseq.incrementBy!!, testerId.value) } finally { SchemaUtils.drop(tester) - assertFalse(customSequence.exists()) + assertFalse(myseq.exists()) } } } @@ -239,17 +221,8 @@ class SequencesTests : DatabaseTestsBase() { @Test fun testExistingSequencesForAutoIncrementWithCustomSequence() { - val customSequence = Sequence( - name = "my_sequence", - startWith = 4, - incrementBy = 2, - minValue = 1, - maxValue = 100, - cycle = true, - cache = 20 - ) val tableWithExplicitSequenceName = object : IdTable() { - override val id: Column> = long("id").autoIncrement(customSequence).entityId() + override val id: Column> = long("id").autoIncrement(myseq).entityId() } withDb { @@ -260,7 +233,7 @@ class SequencesTests : DatabaseTestsBase() { val sequences = currentDialectTest.sequences() assertTrue(sequences.isNotEmpty()) - assertTrue(sequences.any { it == customSequence.name.inProperCase() }) + assertTrue(sequences.any { it == myseq.name.inProperCase() }) } finally { SchemaUtils.drop(tableWithExplicitSequenceName) } @@ -340,18 +313,18 @@ class SequencesTests : DatabaseTestsBase() { private object Developer : Table() { val id = integer("id") - var name = varchar("name", 25) + val name = varchar("name", 25) override val primaryKey = PrimaryKey(id, name) } private object DeveloperWithLongId : LongIdTable() { - var name = varchar("name", 25) + val name = varchar("name", 25) } private object DeveloperWithAutoIncrementBySequence : IdTable() { override val id: Column> = long("id").autoIncrement("id_seq").entityId() - var name = varchar("name", 25) + val name = varchar("name", 25) } private val myseq = Sequence( diff --git a/exposed-tests/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/dml/ReplaceTests.kt b/exposed-tests/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/dml/ReplaceTests.kt index bf269bca33..4369ed08d5 100644 --- a/exposed-tests/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/dml/ReplaceTests.kt +++ b/exposed-tests/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/dml/ReplaceTests.kt @@ -8,6 +8,7 @@ import org.jetbrains.exposed.sql.tests.shared.assertEquals import org.jetbrains.exposed.sql.tests.shared.assertTrue import org.junit.Test import java.util.* +import kotlin.test.assertContentEquals class ReplaceTests : DatabaseTestsBase() { @@ -22,6 +23,58 @@ class ReplaceTests : DatabaseTestsBase() { override val primaryKey = PrimaryKey(username) } + @Test + fun testReplaceSelect() { + withTables(replaceNotSupported, NewAuth) { testDb -> + NewAuth.batchReplace(listOf("username1", "username2")) { // inserts 2 new non-conflict rows with defaults + this[NewAuth.username] = it + this[NewAuth.session] = "session".toByteArray() + } + + val result1 = NewAuth.selectAll().toList() + assertTrue(result1.all { it[NewAuth.timestamp] == 0L }) + assertTrue(result1.all { it[NewAuth.serverID].isEmpty() }) + + val timeNow = System.currentTimeMillis() + val specialId = "special server id" + val allRowsWithNewDefaults = NewAuth.select(NewAuth.username, NewAuth.session, longLiteral(timeNow), stringLiteral(specialId)) + + val affectedRowCount = NewAuth.replace(allRowsWithNewDefaults) + // MySQL returns 1 for every insert + 1 for every delete on conflict, while others only count inserts + val expectedRowCount = if (testDb in TestDB.ALL_MYSQL_LIKE) 4 else 2 + assertEquals(expectedRowCount, affectedRowCount) + + val result2 = NewAuth.selectAll().toList() + assertTrue(result2.all { it[NewAuth.timestamp] == timeNow }) + assertTrue(result2.all { it[NewAuth.serverID] == specialId }) + } + } + + @Test + fun testReplaceSelectWithSpecificColumns() { + withTables(replaceNotSupported, NewAuth) { testDb -> + val (name1, name2, oldSession) = Triple("username1", "username2", "session1".toByteArray()) + NewAuth.batchReplace(listOf(name1, name2)) { // inserts 2 new non-conflict rows with defaults + this[NewAuth.username] = it + this[NewAuth.session] = oldSession + } + + val newSession = "session2" + val name1Row = NewAuth.select(NewAuth.username, stringLiteral(newSession)).where { NewAuth.username eq name1 } + + val affectedRowCount = NewAuth.replace(name1Row, columns = listOf(NewAuth.username, NewAuth.session)) + // MySQL returns 1 for every insert + 1 for every delete on conflict, while others only count inserts + val expectedRowCount = if (testDb in TestDB.ALL_MYSQL_LIKE) 2 else 1 + assertEquals(expectedRowCount, affectedRowCount) + + val name1Result = NewAuth.selectAll().where { NewAuth.username eq name1 }.single() + assertContentEquals(newSession.toByteArray(), name1Result[NewAuth.session]) + + val name2Result = NewAuth.selectAll().where { NewAuth.username eq name2 }.single() + assertContentEquals(oldSession, name2Result[NewAuth.session]) + } + } + @Test fun testReplaceWithPKConflict() { withTables(replaceNotSupported, NewAuth) {