Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: EXPOSED-57 BatchInsertStatement can't be used with MySQL upsert #1754

Merged
merged 3 commits into from
Jun 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,23 @@ package org.jetbrains.exposed.sql.statements

import org.jetbrains.exposed.sql.*
import org.jetbrains.exposed.sql.statements.api.PreparedStatementApi
import org.jetbrains.exposed.sql.vendors.PostgreSQLDialect
import org.jetbrains.exposed.sql.vendors.currentDialect
import org.jetbrains.exposed.sql.vendors.*
import org.jetbrains.exposed.sql.vendors.inProperCase
import java.sql.ResultSet
import java.sql.SQLException
import kotlin.properties.Delegates

open class InsertStatement<Key : Any>(val table: Table, val isIgnore: Boolean = false) : UpdateBuilder<Int>(StatementType.INSERT, listOf(table)) {

/**
* Returns the number of rows affected by the insert operation.
*
* When returned by a `BatchInsertStatement` or `BatchUpsertStatement`, the returned value is calculated using the
* sum of the individual values generated by each statement.
*
* **Note**: Some vendors support returning the affected-row value of 2 if an existing row is updated by an upsert
* operation; please check the documentation.
*/
var insertedCount: Int by Delegates.notNull()

var resultedValues: List<ResultRow>? = null
Expand Down Expand Up @@ -62,12 +70,10 @@ open class InsertStatement<Key : Any>(val table: Table, val isIgnore: Boolean =
}
}

/** TODO: https://github.com/JetBrains/Exposed/issues/129
* doesn't work with MySQL `INSERT ... ON DUPLICATE UPDATE`
*/
// assert(isIgnore || autoGeneratedKeys.isEmpty() || autoGeneratedKeys.size == inserted) {
// "Number of autoincs (${autoGeneratedKeys.size}) doesn't match number of batch entries ($inserted)"
// }
assert(isIgnore || autoGeneratedKeys.isEmpty() ||
autoGeneratedKeys.size == inserted || currentDialect.supportsTernaryAffectedRowValues) {
"Number of autoincs (${autoGeneratedKeys.size}) doesn't match number of batch entries ($inserted)"
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,9 @@ interface DatabaseDialect {
val supportsSequenceAsGeneratedKeys: Boolean get() = supportsCreateSequence
val supportsOnlyIdentifiersInGeneratedKeys: Boolean get() = false

/** Returns `true` if the dialect supports an upsert operation returning an affected-row value of 0, 1, or 2. */
val supportsTernaryAffectedRowValues: Boolean get() = false

/** Returns`true` if the dialect supports schema creation. */
val supportsCreateSchema: Boolean get() = true

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,9 @@ open class H2Dialect : VendorDialect(dialectName, H2DataTypeProvider, H2Function
override val supportsSequenceAsGeneratedKeys: Boolean by lazy {
resolveDelegatedDialect()?.supportsSequenceAsGeneratedKeys ?: super.supportsSequenceAsGeneratedKeys
}
override val supportsTernaryAffectedRowValues: Boolean by lazy {
resolveDelegatedDialect()?.supportsTernaryAffectedRowValues ?: super.supportsTernaryAffectedRowValues
}
override val supportsCreateSchema: Boolean by lazy { resolveDelegatedDialect()?.supportsCreateSchema ?: super.supportsCreateSchema }
override val supportsSubqueryUnions: Boolean by lazy { resolveDelegatedDialect()?.supportsSubqueryUnions ?: super.supportsSubqueryUnions }
override val supportsDualTableConcept: Boolean by lazy { resolveDelegatedDialect()?.supportsDualTableConcept ?: super.supportsDualTableConcept }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,8 @@ open class MysqlDialect : VendorDialect(dialectName, MysqlDataTypeProvider, Mysq

override val supportsCreateSequence: Boolean = false

override val supportsTernaryAffectedRowValues: Boolean = true

override val supportsSubqueryUnions: Boolean = true

override val supportsOrderByNullsFirstLast: Boolean = false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ import org.jetbrains.exposed.sql.transactions.inTopLevelTransaction
import org.jetbrains.exposed.sql.transactions.nullableTransactionScope
import org.jetbrains.exposed.sql.transactions.transaction
import org.jetbrains.exposed.sql.transactions.transactionManager
import org.jetbrains.exposed.sql.vendors.H2Dialect
import org.junit.Assume
import org.junit.AssumptionViolatedException
import org.testcontainers.containers.MySQLContainer
import org.testcontainers.containers.PostgreSQLContainer
import java.math.BigDecimal
import java.sql.Connection
import java.sql.SQLException
import java.time.Duration
Expand Down Expand Up @@ -286,7 +286,7 @@ abstract class DatabaseTestsBase {
}

fun Transaction.excludingH2Version1(dbSettings: TestDB, statement: Transaction.(TestDB) -> Unit) {
if (dbSettings !in TestDB.allH2TestDB || db.isVersionCovers(BigDecimal("2.0"))) {
if (dbSettings !in TestDB.allH2TestDB || (db.dialect as H2Dialect).isSecondVersion) {
statement(dbSettings)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@ import org.jetbrains.exposed.sql.SqlExpressionBuilder.concat
import org.jetbrains.exposed.sql.SqlExpressionBuilder.less
import org.jetbrains.exposed.sql.SqlExpressionBuilder.minus
import org.jetbrains.exposed.sql.SqlExpressionBuilder.plus
import org.jetbrains.exposed.sql.statements.BatchUpsertStatement
import org.jetbrains.exposed.sql.tests.*
import org.jetbrains.exposed.sql.tests.shared.assertEquals
import org.jetbrains.exposed.sql.tests.shared.expectException
import org.junit.Test
import java.util.*
import kotlin.properties.Delegates

// Upsert implementation does not support H2 version 1
// https://youtrack.jetbrains.com/issue/EXPOSED-30/Phase-Out-Support-for-H2-Version-1.x
Expand All @@ -21,31 +23,24 @@ class UpsertTests : DatabaseTestsBase() {

@Test
fun testUpsertWithPKConflict() {
val tester = object : Table("tester") {
val id = integer("id").autoIncrement()
val name = varchar("name", 64)

override val primaryKey = PrimaryKey(id)
}

withTables(tester) { testDb ->
withTables(AutoIncTable) { testDb ->
excludingH2Version1(testDb) {
val id1 = tester.insert {
val id1 = AutoIncTable.insert {
it[name] = "A"
} get tester.id
} get AutoIncTable.id

tester.upsert {
AutoIncTable.upsert {
if (testDb in upsertViaMergeDB) it[id] = 2
it[name] = "B"
}
tester.upsert {
AutoIncTable.upsert {
it[id] = id1
it[name] = "C"
}

assertEquals(2, tester.selectAll().count())
val updatedResult = tester.select { tester.id eq id1 }.single()
assertEquals("C", updatedResult[tester.name])
assertEquals(2, AutoIncTable.selectAll().count())
val updatedResult = AutoIncTable.select { AutoIncTable.id eq id1 }.single()
assertEquals("C", updatedResult[AutoIncTable.name])
}
}
}
Expand Down Expand Up @@ -424,6 +419,56 @@ class UpsertTests : DatabaseTestsBase() {
}
}

@Test
fun testInsertedCountWithBatchUpsert() {
withTables(AutoIncTable) { testDb ->
excludingH2Version1(testDb) {
// SQL Server requires statements to be executed before results can be obtained
val isNotSqlServer = testDb != TestDB.SQLSERVER
val data = listOf(1 to "A", 2 to "B", 3 to "C")
val newDataSize = data.size
var statement: BatchUpsertStatement by Delegates.notNull()

// all new rows inserted
AutoIncTable.batchUpsert(data, shouldReturnGeneratedValues = isNotSqlServer) { (id, name) ->
statement = this
this[AutoIncTable.id] = id
this[AutoIncTable.name] = name
}
assertEquals(newDataSize, statement.insertedCount)

// all existing rows set to their current values
val isH2MysqlMode = testDb == TestDB.H2_MYSQL || testDb == TestDB.H2_MARIADB
var expected = if (isH2MysqlMode) 0 else newDataSize
AutoIncTable.batchUpsert(data, shouldReturnGeneratedValues = isNotSqlServer) { (id, name) ->
statement = this
this[AutoIncTable.id] = id
this[AutoIncTable.name] = name
}
assertEquals(expected, statement.insertedCount)

// all existing rows updated & 1 new row inserted
val updatedData = data.map { it.first to "new${it.second}" } + (4 to "D")
expected = if (testDb in TestDB.mySqlRelatedDB) newDataSize * 2 + 1 else newDataSize + 1
AutoIncTable.batchUpsert(updatedData, shouldReturnGeneratedValues = isNotSqlServer) { (id, name) ->
statement = this
this[AutoIncTable.id] = id
this[AutoIncTable.name] = name
}
assertEquals(expected, statement.insertedCount)

assertEquals(updatedData.size.toLong(), AutoIncTable.selectAll().count())
}
}
}

private object AutoIncTable : Table("auto_inc_table") {
val id = integer("id").autoIncrement()
val name = varchar("name", 64)

override val primaryKey = PrimaryKey(id)
}

private object Words : Table("words") {
val word = varchar("name", 64).uniqueIndex()
val count = integer("count").default(1)
Expand Down