From e5f628a93bcf673df9cc60ad878366b3269d1861 Mon Sep 17 00:00:00 2001 From: Tapac Date: Sun, 3 Nov 2019 21:16:29 +0900 Subject: [PATCH] IllegalArgumentException thrown when List>.awaitAll() is used on List> #658 --- .../transactions/experimental/Suspended.kt | 76 +++++++------------ .../org/jetbrains/exposed/sql/vendors/H2.kt | 7 +- .../sql/tests/shared/CoroutineTests.kt | 37 ++++++++- 3 files changed, 67 insertions(+), 53 deletions(-) diff --git a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/transactions/experimental/Suspended.kt b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/transactions/experimental/Suspended.kt index 06f6be48ce..7d09f2cc78 100644 --- a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/transactions/experimental/Suspended.kt +++ b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/transactions/experimental/Suspended.kt @@ -54,57 +54,37 @@ suspend fun Transaction.suspendedTransaction(context: CoroutineDispatcher? = suspendedTransactionAsyncInternal(false, statement).await() } -class TransactionResult(internal val transaction: Transaction, - internal val deferred: Deferred, - internal var shouldCommit: Boolean) : Deferred by deferred { - override suspend fun await(): T { - return deferred.await().apply { - if (shouldCommit) { - val currentTransaction = TransactionManager.currentOrNull() - try { - val temporaryManager = transaction.db.transactionManager - (temporaryManager as? ThreadLocalTransactionManager)?.threadLocal?.set(transaction) - TransactionManager.resetCurrent(temporaryManager) - with(transaction) { - try { - commit() - try { - currentStatement?.let { - it.closeIfPossible() - currentStatement = null - } - closeExecutedStatements() - } catch (e: Exception) { - exposedLogger.warn("Statements close failed", e) - } - closeLoggingException { exposedLogger.warn("Transaction close failed: ${it.message}. Statement: $currentStatement", it) } - } catch (e: Exception) { - rollbackLoggingException { exposedLogger.warn("Transaction rollback failed: ${it.message}. Statement: $currentStatement", it) } - throw e - } - } - } finally { - val transactionManager = currentTransaction?.db?.transactionManager - (transactionManager as? ThreadLocalTransactionManager)?.threadLocal?.set(currentTransaction) - TransactionManager.resetCurrent(transactionManager) +private fun Transaction.commitInAsync() { + val currentTransaction = TransactionManager.currentOrNull() + try { + val temporaryManager = this.db.transactionManager + (temporaryManager as? ThreadLocalTransactionManager)?.threadLocal?.set(this) + TransactionManager.resetCurrent(temporaryManager) + try { + commit() + try { + currentStatement?.let { + it.closeIfPossible() + currentStatement = null } + closeExecutedStatements() + } catch (e: Exception) { + exposedLogger.warn("Statements close failed", e) } + closeLoggingException { exposedLogger.warn("Transaction close failed: ${it.message}. Statement: $currentStatement", it) } + } catch (e: Exception) { + rollbackLoggingException { exposedLogger.warn("Transaction rollback failed: ${it.message}. Statement: $currentStatement", it) } + throw e } - } -} - -suspend fun TransactionResult.andThen(statement: suspend Transaction.(T) -> R) : TransactionResult { - val currentAsync = this - return withTransactionScope(null, currentAsync.transaction, null) { - currentAsync.shouldCommit = false - suspendedTransactionAsyncInternal(true) { - statement(currentAsync.await()) - } + } finally { + val transactionManager = currentTransaction?.db?.transactionManager + (transactionManager as? ThreadLocalTransactionManager)?.threadLocal?.set(currentTransaction) + TransactionManager.resetCurrent(transactionManager) } } suspend fun suspendedTransactionAsync(context: CoroutineDispatcher? = null, db: Database? = null, - statement: suspend Transaction.() -> T) : TransactionResult { + statement: suspend Transaction.() -> T) : Deferred { val currentTransaction = TransactionManager.currentOrNull() return withTransactionScope(context, null, db) { suspendedTransactionAsyncInternal(currentTransaction != tx, statement) @@ -137,12 +117,14 @@ private suspend fun withTransactionScope(context: CoroutineContext?, } private fun TransactionScope.suspendedTransactionAsyncInternal(shouldCommit: Boolean, - statement: suspend Transaction.() -> T) : TransactionResult - = TransactionResult(tx, async { + statement: suspend Transaction.() -> T) : Deferred + = async { try { tx.statement() } catch (e: Throwable) { tx.rollbackLoggingException { exposedLogger.warn("Transaction rollback failed: ${it.message}. Statement: ${tx.currentStatement}", it) } throw e + } finally { + if (shouldCommit) tx.commitInAsync() } - }, shouldCommit) \ No newline at end of file + } \ No newline at end of file diff --git a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/vendors/H2.kt b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/vendors/H2.kt index d93081923c..12403cc3d0 100644 --- a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/vendors/H2.kt +++ b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/vendors/H2.kt @@ -13,7 +13,10 @@ internal object H2DataTypeProvider : DataTypeProvider() { } private val Transaction.isMySQLMode: Boolean - get() = (connection.connection as? JdbcConnection)?.settings?.mode?.enum == Mode.ModeEnum.MySQL + get() = (connection.connection as? JdbcConnection)?.let { + !it.isClosed && it.settings.mode.enum == Mode.ModeEnum.MySQL + } == true + internal object H2FunctionProvider : FunctionProvider() { @@ -74,7 +77,7 @@ open class H2Dialect : VendorDialect(dialectName, H2DataTypeProvider, H2Function override val name: String get() = when { - TransactionManager.current().isMySQLMode -> "$dialectName (Mysql Mode)" + TransactionManager.currentOrNull()?.isMySQLMode == true -> "$dialectName (Mysql Mode)" else -> dialectName } companion object { diff --git a/exposed-tests/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/CoroutineTests.kt b/exposed-tests/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/CoroutineTests.kt index c9643531b6..78af7e6473 100644 --- a/exposed-tests/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/CoroutineTests.kt +++ b/exposed-tests/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/CoroutineTests.kt @@ -5,10 +5,10 @@ import kotlinx.coroutines.debug.junit4.CoroutinesTimeout import org.jetbrains.exposed.sql.* import org.jetbrains.exposed.sql.tests.DatabaseTestsBase import org.jetbrains.exposed.sql.tests.TestDB -import org.jetbrains.exposed.sql.transactions.experimental.andThen import org.jetbrains.exposed.sql.transactions.experimental.newSuspendedTransaction import org.jetbrains.exposed.sql.transactions.experimental.suspendedTransaction import org.jetbrains.exposed.sql.transactions.experimental.suspendedTransactionAsync +import org.jetbrains.exposed.sql.transactions.transaction import org.jetbrains.exposed.test.utils.RepeatableTest import org.junit.Rule import org.junit.Test @@ -68,12 +68,14 @@ class CoroutineTests : DatabaseTestsBase() { launchResult.await() val result = suspendedTransactionAsync(Dispatchers.Default, db = db) { Testing.select { Testing.id.eq(1) }.single()[Testing.id] - }.andThen { - assertEquals(1, it) + }.await() + + val result2 = suspendedTransactionAsync(Dispatchers.Default, db = db) { + assertEquals(1, result) Testing.selectAll().count() } - kotlin.test.assertEquals(1, result.await()) + kotlin.test.assertEquals(1, result2.await()) } while (!job.isCompleted) Thread.sleep(100) @@ -113,4 +115,31 @@ class CoroutineTests : DatabaseTestsBase() { mainJob.getCompletionExceptionOrNull()?.let { throw it } } } + + @Test @RepeatableTest(10) + fun awaitAllTest() { + suspend fun insertTesting(db: Database) = newSuspendedTransaction(db = db) { + Testing.insert {} + } + withTables(listOf(TestDB.SQLITE), Testing) { + val mainJob = GlobalScope.async { + + val results = (1..5).map { indx -> + suspendedTransactionAsync(Dispatchers.IO, db = db) { + Testing.insert { } + indx + } + }.awaitAll() + + kotlin.test.assertEquals(15, results.sum()) + } + + while (!mainJob.isCompleted) Thread.sleep(100) + mainJob.getCompletionExceptionOrNull()?.let { throw it } + + transaction { + assertEquals(5, Testing.selectAll().count()) + } + } + } } \ No newline at end of file