Skip to content

Commit

Permalink
feat: EXPOSED-486 Support REPLACE INTO ... SELECT clause (#2199)
Browse files Browse the repository at this point in the history
* feat: EXPOSED-486 Support REPLACE INTO ... SELECT clause

Existing insert() has an overload that accepts a query argument to generate SQL like
INSERT INTO ... SELECT. replace() also exists for supporting databases, but does
not have a similar overload.

This implements a new statement class that extends InsertSelectStatement and provides
a new extension function overload.
  • Loading branch information
bog-walk authored Aug 16, 2024
1 parent 6f2019d commit f2670fc
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 52 deletions.
29 changes: 26 additions & 3 deletions documentation-website/Writerside/topics/Deep-Dive-into-DSL.md
Original file line number Diff line number Diff line change
Expand Up @@ -616,15 +616,15 @@ PostgresSQL [here](https://jdbc.postgresql.org/documentation/logging/).

## Insert From Select

If you want to use `INSERT INTO ... SELECT ` SQL clause try Exposed analog `Table.insert(Query)`.
If you want to use the `INSERT INTO ... SELECT ` SQL clause try the function `Table.insert(Query)`:

```kotlin
val substring = users.name.substring(1, 2)
cities.insert(users.select(substring).orderBy(users.id).limit(2))
```

By default it will try to insert into all non auto-increment `Table` columns in order they defined in Table instance. If you want to specify columns or change the
order, provide list of columns as second parameter:
By default, it will try to insert into all non auto-increment `Table` columns in the order they are defined in the `Table` instance. If you want to specify columns or change the
order, provide a list of columns as the second parameter:

```kotlin
val userCount = users.selectAll().count()
Expand Down Expand Up @@ -767,6 +767,29 @@ In the example above, if the original row was inserted with a user-defined <code
the newly inserted row would store the default rating value. This is because the old row was completely deleted first.
</note>

The `REPLACE INTO ... SELECT ` SQL clause can be used by instead providing a query to `replace()`:

```kotlin
val allRowsWithLowRating: Query = StarWarsFilms.selectAll().where {
StarWarsFilms.rating less 5.0
}
StarWarsFilms.replace(allRowsWithLowRating)
```

By default, it will try to insert into all non auto-increment `Table` columns in the order they are defined in the `Table` instance.
If the columns need to be specified or the order should be changed, provide a list of columns as the second parameter:

```kotlin
val oneYearLater = StarWarsFilms.releaseYear.plus(1)
val allRowsWithNewYear: Query = StarWarsFilms.select(
oneYearLater, StarWarsFilms.sequelId, StarWarsFilms.director, StarWarsFilms.name
)
StarWarsFilms.replace(
allRowsWithNewYear,
columns = listOf(StarWarsFilms.releaseYear, StarWarsFilms.sequelId, StarWarsFilms.director, StarWarsFilms.name)
)
```

## Column transformation

Column transformations allow to define custom transformations between database column types and application's data types.
Expand Down
7 changes: 7 additions & 0 deletions exposed-core/api/exposed-core.api
Original file line number Diff line number Diff line change
Expand Up @@ -1817,6 +1817,8 @@ public final class org/jetbrains/exposed/sql/QueriesKt {
public static final fun mergeFrom (Lorg/jetbrains/exposed/sql/Table;Lorg/jetbrains/exposed/sql/Table;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;)Lorg/jetbrains/exposed/sql/statements/MergeTableStatement;
public static synthetic fun mergeFrom$default (Lorg/jetbrains/exposed/sql/Table;Lorg/jetbrains/exposed/sql/Table;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lorg/jetbrains/exposed/sql/statements/MergeTableStatement;
public static final fun replace (Lorg/jetbrains/exposed/sql/Table;Lkotlin/jvm/functions/Function2;)Lorg/jetbrains/exposed/sql/statements/ReplaceStatement;
public static final fun replace (Lorg/jetbrains/exposed/sql/Table;Lorg/jetbrains/exposed/sql/AbstractQuery;Ljava/util/List;)Ljava/lang/Integer;
public static synthetic fun replace$default (Lorg/jetbrains/exposed/sql/Table;Lorg/jetbrains/exposed/sql/AbstractQuery;Ljava/util/List;ILjava/lang/Object;)Ljava/lang/Integer;
public static final fun select (Lorg/jetbrains/exposed/sql/FieldSet;Lkotlin/jvm/functions/Function1;)Lorg/jetbrains/exposed/sql/Query;
public static final fun select (Lorg/jetbrains/exposed/sql/FieldSet;Lorg/jetbrains/exposed/sql/Op;)Lorg/jetbrains/exposed/sql/Query;
public static final fun select (Lorg/jetbrains/exposed/sql/Query;Lkotlin/jvm/functions/Function1;)Lorg/jetbrains/exposed/sql/Query;
Expand Down Expand Up @@ -3213,6 +3215,11 @@ public class org/jetbrains/exposed/sql/statements/MergeTableStatement : org/jetb
public fun prepareSQL (Lorg/jetbrains/exposed/sql/Transaction;Z)Ljava/lang/String;
}

public class org/jetbrains/exposed/sql/statements/ReplaceSelectStatement : org/jetbrains/exposed/sql/statements/InsertSelectStatement {
public fun <init> (Ljava/util/List;Lorg/jetbrains/exposed/sql/AbstractQuery;)V
public fun prepareSQL (Lorg/jetbrains/exposed/sql/Transaction;Z)Ljava/lang/String;
}

public class org/jetbrains/exposed/sql/statements/ReplaceStatement : org/jetbrains/exposed/sql/statements/InsertStatement {
public fun <init> (Lorg/jetbrains/exposed/sql/Table;)V
public fun prepareSQL (Lorg/jetbrains/exposed/sql/Transaction;Z)Ljava/lang/String;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ class AutoIncColumnType<T>(
val IColumnType<*>.isAutoInc: Boolean
get() = this is AutoIncColumnType || (this is EntityIDColumnType<*> && idColumn.columnType.isAutoInc)

/** Returns the name of the auto-increment sequence of this column. */
/** Returns this column's type cast as [AutoIncColumnType] or `null` if the cast fails. */
val Column<*>.autoIncColumnType: AutoIncColumnType<*>?
get() = (columnType as? AutoIncColumnType)
?: (columnType as? EntityIDColumnType<*>)?.idColumn?.columnType as? AutoIncColumnType
Expand Down
24 changes: 22 additions & 2 deletions exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Queries.kt
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,23 @@ fun <T : Table> T.replace(body: T.(UpdateBuilder<*>) -> Unit): ReplaceStatement<
execute(TransactionManager.current())
}

/**
* Represents the SQL statement that uses data retrieved from a [selectQuery] to either insert a new row into a table,
* or, if insertion would violate a unique constraint, first delete the existing row before inserting a new row.
*
* **Note:** This operation is not supported by all vendors, please check the documentation.
*
* @param selectQuery Source `SELECT` query that provides the values to insert.
* @param columns Columns to either insert values into or delete values from then insert into. This defaults to all
* columns in the table that are not auto-increment columns without a valid sequence to generate new values.
* @return The number of inserted (and possibly deleted) rows, or `null` if nothing was retrieved after statement execution.
* @sample org.jetbrains.exposed.sql.tests.shared.dml.ReplaceTests.testReplaceSelect
*/
fun <T : Table> T.replace(
selectQuery: AbstractQuery<*>,
columns: List<Column<*>> = this.columns.filter { it.isValidIfAutoIncrement() }
): Int? = ReplaceSelectStatement(columns, selectQuery).execute(TransactionManager.current())

/**
* Represents the SQL statement that uses data retrieved from a [selectQuery] to insert new rows into a table.
*
Expand All @@ -368,7 +385,7 @@ fun <T : Table> T.replace(body: T.(UpdateBuilder<*>) -> Unit): ReplaceStatement<
*/
fun <T : Table> T.insert(
selectQuery: AbstractQuery<*>,
columns: List<Column<*>> = this.columns.filter { !it.columnType.isAutoInc || it.autoIncColumnType?.nextValExpression != null }
columns: List<Column<*>> = this.columns.filter { it.isValidIfAutoIncrement() }
): Int? = InsertSelectStatement(columns, selectQuery).execute(TransactionManager.current())

/**
Expand All @@ -384,9 +401,12 @@ fun <T : Table> T.insert(
*/
fun <T : Table> T.insertIgnore(
selectQuery: AbstractQuery<*>,
columns: List<Column<*>> = this.columns.filter { !it.columnType.isAutoInc || it.autoIncColumnType?.nextValExpression != null }
columns: List<Column<*>> = this.columns.filter { it.isValidIfAutoIncrement() }
): Int? = InsertSelectStatement(columns, selectQuery, true).execute(TransactionManager.current())

private fun Column<*>.isValidIfAutoIncrement(): Boolean =
!columnType.isAutoInc || autoIncColumnType?.nextValExpression != null

/**
* Represents the SQL statement that inserts new rows into a table and returns specified data from the inserted rows.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package org.jetbrains.exposed.sql.statements

import org.jetbrains.exposed.sql.AbstractQuery
import org.jetbrains.exposed.sql.Column
import org.jetbrains.exposed.sql.Table
import org.jetbrains.exposed.sql.Transaction
import org.jetbrains.exposed.sql.vendors.H2Dialect
Expand All @@ -24,3 +26,25 @@ open class ReplaceStatement<Key : Any>(table: Table) : InsertStatement<Key>(tabl
return functionProvider.replace(table, values.unzip().first, valuesSql, transaction, prepared)
}
}

/**
* Represents the SQL statement that uses data retrieved from a [selectQuery] to either insert a new row into a table,
* or, if insertion would violate a unique constraint, first delete the existing row before inserting a new row.
*
* @param columns Columns to either insert values into or delete values from then insert into.
* @param selectQuery Source SELECT query that provides the values to insert.
*/
open class ReplaceSelectStatement(
columns: List<Column<*>>,
selectQuery: AbstractQuery<*>
) : InsertSelectStatement(columns, selectQuery) {
override fun prepareSQL(transaction: Transaction, prepared: Boolean): String {
val querySql = selectQuery.prepareSQL(transaction, prepared)
val dialect = transaction.db.dialect
val functionProvider = when (dialect.h2Mode) {
H2Dialect.H2CompatibilityMode.MySQL, H2Dialect.H2CompatibilityMode.MariaDB -> MysqlFunctionProvider()
else -> dialect.functionProvider
}
return functionProvider.replace(targets.single(), columns, querySql, transaction, prepared)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,42 +63,33 @@ class SequencesTests : DatabaseTestsBase() {
}

@Test
fun `testInsertWithCustomSequence`() {
val customSequence = Sequence(
name = "my_sequence",
startWith = 4,
incrementBy = 2,
minValue = 1,
maxValue = 100,
cycle = true,
cache = 20
)
fun testInsertWithCustomSequence() {
val tester = object : Table("tester") {
val id = integer("id").autoIncrement(customSequence)
var name = varchar("name", 25)
val id = integer("id").autoIncrement(myseq)
val name = varchar("name", 25)

override val primaryKey = PrimaryKey(id, name)
}
withDb {
if (currentDialectTest.supportsSequenceAsGeneratedKeys) {
try {
SchemaUtils.create(tester)
assertTrue(customSequence.exists())
assertTrue(myseq.exists())

var testerId = tester.insert {
it[name] = "Hichem"
} get tester.id

assertEquals(customSequence.startWith, testerId.toLong())
assertEquals(myseq.startWith, testerId.toLong())

testerId = tester.insert {
it[name] = "Andrey"
} get tester.id

assertEquals(customSequence.startWith!! + customSequence.incrementBy!!, testerId.toLong())
assertEquals(myseq.startWith!! + myseq.incrementBy!!, testerId.toLong())
} finally {
SchemaUtils.drop(tester)
assertFalse(customSequence.exists())
assertFalse(myseq.exists())
}
}
}
Expand Down Expand Up @@ -131,42 +122,33 @@ class SequencesTests : DatabaseTestsBase() {
}

@Test
fun `testInsertInIdTableWithCustomSequence`() {
val customSequence = Sequence(
name = "my_sequence",
startWith = 4,
incrementBy = 2,
minValue = 1,
maxValue = 100,
cycle = true,
cache = 20
)
fun testInsertInIdTableWithCustomSequence() {
val tester = object : IdTable<Long>("tester") {
override val id = long("id").autoIncrement(customSequence).entityId()
var name = varchar("name", 25)
override val id = long("id").autoIncrement(myseq).entityId()
val name = varchar("name", 25)

override val primaryKey = PrimaryKey(id, name)
}
withDb {
if (currentDialectTest.supportsSequenceAsGeneratedKeys) {
try {
SchemaUtils.create(tester)
assertTrue(customSequence.exists())
assertTrue(myseq.exists())

var testerId = tester.insert {
it[name] = "Hichem"
} get tester.id

assertEquals(customSequence.startWith, testerId.value)
assertEquals(myseq.startWith, testerId.value)

testerId = tester.insert {
it[name] = "Andrey"
} get tester.id

assertEquals(customSequence.startWith!! + customSequence.incrementBy!!, testerId.value)
assertEquals(myseq.startWith!! + myseq.incrementBy!!, testerId.value)
} finally {
SchemaUtils.drop(tester)
assertFalse(customSequence.exists())
assertFalse(myseq.exists())
}
}
}
Expand Down Expand Up @@ -239,17 +221,8 @@ class SequencesTests : DatabaseTestsBase() {

@Test
fun testExistingSequencesForAutoIncrementWithCustomSequence() {
val customSequence = Sequence(
name = "my_sequence",
startWith = 4,
incrementBy = 2,
minValue = 1,
maxValue = 100,
cycle = true,
cache = 20
)
val tableWithExplicitSequenceName = object : IdTable<Long>() {
override val id: Column<EntityID<Long>> = long("id").autoIncrement(customSequence).entityId()
override val id: Column<EntityID<Long>> = long("id").autoIncrement(myseq).entityId()
}

withDb {
Expand All @@ -260,7 +233,7 @@ class SequencesTests : DatabaseTestsBase() {
val sequences = currentDialectTest.sequences()

assertTrue(sequences.isNotEmpty())
assertTrue(sequences.any { it == customSequence.name.inProperCase() })
assertTrue(sequences.any { it == myseq.name.inProperCase() })
} finally {
SchemaUtils.drop(tableWithExplicitSequenceName)
}
Expand Down Expand Up @@ -340,18 +313,18 @@ class SequencesTests : DatabaseTestsBase() {

private object Developer : Table() {
val id = integer("id")
var name = varchar("name", 25)
val name = varchar("name", 25)

override val primaryKey = PrimaryKey(id, name)
}

private object DeveloperWithLongId : LongIdTable() {
var name = varchar("name", 25)
val name = varchar("name", 25)
}

private object DeveloperWithAutoIncrementBySequence : IdTable<Long>() {
override val id: Column<EntityID<Long>> = long("id").autoIncrement("id_seq").entityId()
var name = varchar("name", 25)
val name = varchar("name", 25)
}

private val myseq = Sequence(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import org.jetbrains.exposed.sql.tests.shared.assertEquals
import org.jetbrains.exposed.sql.tests.shared.assertTrue
import org.junit.Test
import java.util.*
import kotlin.test.assertContentEquals

class ReplaceTests : DatabaseTestsBase() {

Expand All @@ -22,6 +23,58 @@ class ReplaceTests : DatabaseTestsBase() {
override val primaryKey = PrimaryKey(username)
}

@Test
fun testReplaceSelect() {
withTables(replaceNotSupported, NewAuth) { testDb ->
NewAuth.batchReplace(listOf("username1", "username2")) { // inserts 2 new non-conflict rows with defaults
this[NewAuth.username] = it
this[NewAuth.session] = "session".toByteArray()
}

val result1 = NewAuth.selectAll().toList()
assertTrue(result1.all { it[NewAuth.timestamp] == 0L })
assertTrue(result1.all { it[NewAuth.serverID].isEmpty() })

val timeNow = System.currentTimeMillis()
val specialId = "special server id"
val allRowsWithNewDefaults = NewAuth.select(NewAuth.username, NewAuth.session, longLiteral(timeNow), stringLiteral(specialId))

val affectedRowCount = NewAuth.replace(allRowsWithNewDefaults)
// MySQL returns 1 for every insert + 1 for every delete on conflict, while others only count inserts
val expectedRowCount = if (testDb in TestDB.ALL_MYSQL_LIKE) 4 else 2
assertEquals(expectedRowCount, affectedRowCount)

val result2 = NewAuth.selectAll().toList()
assertTrue(result2.all { it[NewAuth.timestamp] == timeNow })
assertTrue(result2.all { it[NewAuth.serverID] == specialId })
}
}

@Test
fun testReplaceSelectWithSpecificColumns() {
withTables(replaceNotSupported, NewAuth) { testDb ->
val (name1, name2, oldSession) = Triple("username1", "username2", "session1".toByteArray())
NewAuth.batchReplace(listOf(name1, name2)) { // inserts 2 new non-conflict rows with defaults
this[NewAuth.username] = it
this[NewAuth.session] = oldSession
}

val newSession = "session2"
val name1Row = NewAuth.select(NewAuth.username, stringLiteral(newSession)).where { NewAuth.username eq name1 }

val affectedRowCount = NewAuth.replace(name1Row, columns = listOf(NewAuth.username, NewAuth.session))
// MySQL returns 1 for every insert + 1 for every delete on conflict, while others only count inserts
val expectedRowCount = if (testDb in TestDB.ALL_MYSQL_LIKE) 2 else 1
assertEquals(expectedRowCount, affectedRowCount)

val name1Result = NewAuth.selectAll().where { NewAuth.username eq name1 }.single()
assertContentEquals(newSession.toByteArray(), name1Result[NewAuth.session])

val name2Result = NewAuth.selectAll().where { NewAuth.username eq name2 }.single()
assertContentEquals(oldSession, name2Result[NewAuth.session])
}
}

@Test
fun testReplaceWithPKConflict() {
withTables(replaceNotSupported, NewAuth) {
Expand Down

0 comments on commit f2670fc

Please sign in to comment.