Skip to content

Commit

Permalink
IllegalArgumentException thrown when List<Deferred<T>>.awaitAll() is …
Browse files Browse the repository at this point in the history
…used on List<TransactionResult<T>> #658
  • Loading branch information
Tapac committed Nov 3, 2019
1 parent 2585ae6 commit e5f628a
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,57 +54,37 @@ suspend fun <T> Transaction.suspendedTransaction(context: CoroutineDispatcher? =
suspendedTransactionAsyncInternal(false, statement).await()
}

class TransactionResult<T>(internal val transaction: Transaction,
internal val deferred: Deferred<T>,
internal var shouldCommit: Boolean) : Deferred<T> by deferred {
override suspend fun await(): T {
return deferred.await().apply {
if (shouldCommit) {
val currentTransaction = TransactionManager.currentOrNull()
try {
val temporaryManager = transaction.db.transactionManager
(temporaryManager as? ThreadLocalTransactionManager)?.threadLocal?.set(transaction)
TransactionManager.resetCurrent(temporaryManager)
with(transaction) {
try {
commit()
try {
currentStatement?.let {
it.closeIfPossible()
currentStatement = null
}
closeExecutedStatements()
} catch (e: Exception) {
exposedLogger.warn("Statements close failed", e)
}
closeLoggingException { exposedLogger.warn("Transaction close failed: ${it.message}. Statement: $currentStatement", it) }
} catch (e: Exception) {
rollbackLoggingException { exposedLogger.warn("Transaction rollback failed: ${it.message}. Statement: $currentStatement", it) }
throw e
}
}
} finally {
val transactionManager = currentTransaction?.db?.transactionManager
(transactionManager as? ThreadLocalTransactionManager)?.threadLocal?.set(currentTransaction)
TransactionManager.resetCurrent(transactionManager)
private fun Transaction.commitInAsync() {
val currentTransaction = TransactionManager.currentOrNull()
try {
val temporaryManager = this.db.transactionManager
(temporaryManager as? ThreadLocalTransactionManager)?.threadLocal?.set(this)
TransactionManager.resetCurrent(temporaryManager)
try {
commit()
try {
currentStatement?.let {
it.closeIfPossible()
currentStatement = null
}
closeExecutedStatements()
} catch (e: Exception) {
exposedLogger.warn("Statements close failed", e)
}
closeLoggingException { exposedLogger.warn("Transaction close failed: ${it.message}. Statement: $currentStatement", it) }
} catch (e: Exception) {
rollbackLoggingException { exposedLogger.warn("Transaction rollback failed: ${it.message}. Statement: $currentStatement", it) }
throw e
}
}
}

suspend fun <T, R> TransactionResult<T>.andThen(statement: suspend Transaction.(T) -> R) : TransactionResult<R> {
val currentAsync = this
return withTransactionScope(null, currentAsync.transaction, null) {
currentAsync.shouldCommit = false
suspendedTransactionAsyncInternal(true) {
statement(currentAsync.await())
}
} finally {
val transactionManager = currentTransaction?.db?.transactionManager
(transactionManager as? ThreadLocalTransactionManager)?.threadLocal?.set(currentTransaction)
TransactionManager.resetCurrent(transactionManager)
}
}

suspend fun <T> suspendedTransactionAsync(context: CoroutineDispatcher? = null, db: Database? = null,
statement: suspend Transaction.() -> T) : TransactionResult<T> {
statement: suspend Transaction.() -> T) : Deferred<T> {
val currentTransaction = TransactionManager.currentOrNull()
return withTransactionScope(context, null, db) {
suspendedTransactionAsyncInternal(currentTransaction != tx, statement)
Expand Down Expand Up @@ -137,12 +117,14 @@ private suspend fun <T> withTransactionScope(context: CoroutineContext?,
}

private fun <T> TransactionScope.suspendedTransactionAsyncInternal(shouldCommit: Boolean,
statement: suspend Transaction.() -> T) : TransactionResult<T>
= TransactionResult(tx, async {
statement: suspend Transaction.() -> T) : Deferred<T>
= async {
try {
tx.statement()
} catch (e: Throwable) {
tx.rollbackLoggingException { exposedLogger.warn("Transaction rollback failed: ${it.message}. Statement: ${tx.currentStatement}", it) }
throw e
} finally {
if (shouldCommit) tx.commitInAsync()
}
}, shouldCommit)
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@ internal object H2DataTypeProvider : DataTypeProvider() {
}

private val Transaction.isMySQLMode: Boolean
get() = (connection.connection as? JdbcConnection)?.settings?.mode?.enum == Mode.ModeEnum.MySQL
get() = (connection.connection as? JdbcConnection)?.let {
!it.isClosed && it.settings.mode.enum == Mode.ModeEnum.MySQL
} == true


internal object H2FunctionProvider : FunctionProvider() {

Expand Down Expand Up @@ -74,7 +77,7 @@ open class H2Dialect : VendorDialect(dialectName, H2DataTypeProvider, H2Function

override val name: String
get() = when {
TransactionManager.current().isMySQLMode -> "$dialectName (Mysql Mode)"
TransactionManager.currentOrNull()?.isMySQLMode == true -> "$dialectName (Mysql Mode)"
else -> dialectName
}
companion object {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ import kotlinx.coroutines.debug.junit4.CoroutinesTimeout
import org.jetbrains.exposed.sql.*
import org.jetbrains.exposed.sql.tests.DatabaseTestsBase
import org.jetbrains.exposed.sql.tests.TestDB
import org.jetbrains.exposed.sql.transactions.experimental.andThen
import org.jetbrains.exposed.sql.transactions.experimental.newSuspendedTransaction
import org.jetbrains.exposed.sql.transactions.experimental.suspendedTransaction
import org.jetbrains.exposed.sql.transactions.experimental.suspendedTransactionAsync
import org.jetbrains.exposed.sql.transactions.transaction
import org.jetbrains.exposed.test.utils.RepeatableTest
import org.junit.Rule
import org.junit.Test
Expand Down Expand Up @@ -68,12 +68,14 @@ class CoroutineTests : DatabaseTestsBase() {
launchResult.await()
val result = suspendedTransactionAsync(Dispatchers.Default, db = db) {
Testing.select { Testing.id.eq(1) }.single()[Testing.id]
}.andThen {
assertEquals(1, it)
}.await()

val result2 = suspendedTransactionAsync(Dispatchers.Default, db = db) {
assertEquals(1, result)
Testing.selectAll().count()
}

kotlin.test.assertEquals(1, result.await())
kotlin.test.assertEquals(1, result2.await())
}

while (!job.isCompleted) Thread.sleep(100)
Expand Down Expand Up @@ -113,4 +115,31 @@ class CoroutineTests : DatabaseTestsBase() {
mainJob.getCompletionExceptionOrNull()?.let { throw it }
}
}

@Test @RepeatableTest(10)
fun awaitAllTest() {
suspend fun insertTesting(db: Database) = newSuspendedTransaction(db = db) {
Testing.insert {}
}
withTables(listOf(TestDB.SQLITE), Testing) {
val mainJob = GlobalScope.async {

val results = (1..5).map { indx ->
suspendedTransactionAsync(Dispatchers.IO, db = db) {
Testing.insert { }
indx
}
}.awaitAll()

kotlin.test.assertEquals(15, results.sum())
}

while (!mainJob.isCompleted) Thread.sleep(100)
mainJob.getCompletionExceptionOrNull()?.let { throw it }

transaction {
assertEquals(5, Testing.selectAll().count())
}
}
}
}

0 comments on commit e5f628a

Please sign in to comment.