diff --git a/ktor-utils/common/src/io/ktor/util/pipeline/SuspendFunctionGun.kt b/ktor-utils/common/src/io/ktor/util/pipeline/SuspendFunctionGun.kt index 2ac171736eb..d8cbadfce62 100644 --- a/ktor-utils/common/src/io/ktor/util/pipeline/SuspendFunctionGun.kt +++ b/ktor-utils/common/src/io/ktor/util/pipeline/SuspendFunctionGun.kt @@ -18,7 +18,7 @@ internal class SuspendFunctionGun( // this is impossible to inline because of property name clash // between PipelineContext.context and Continuation.context - private val continuation: Continuation = object : Continuation, CoroutineStackFrame { + internal val continuation: Continuation = object : Continuation, CoroutineStackFrame { override val callerFrame: CoroutineStackFrame? get() = peekContinuation() as? CoroutineStackFrame var currentIndex: Int = Int.MIN_VALUE @@ -48,7 +48,18 @@ internal class SuspendFunctionGun( } override val context: CoroutineContext - get() = suspensions[lastSuspensionIndex]?.context ?: error("Not started") + get() { + val continuation = suspensions[lastSuspensionIndex] + if (continuation !== this && continuation != null) return continuation.context + + var index = lastSuspensionIndex - 1 + while (index >= 0) { + val cont = suspensions[index--] + if (cont !== this && cont != null) return cont.context + } + + error("Not started") + } override fun resumeWith(result: Result) { if (result.isFailure) { @@ -144,7 +155,7 @@ internal class SuspendFunctionGun( suspensions[lastSuspensionIndex--] = null } - private fun addContinuation(continuation: Continuation) { + internal fun addContinuation(continuation: Continuation) { suspensions[++lastSuspensionIndex] = continuation } } diff --git a/ktor-utils/common/test/io/ktor/util/SuspendFunctionGunTest.kt b/ktor-utils/common/test/io/ktor/util/SuspendFunctionGunTest.kt new file mode 100644 index 00000000000..062614d60d2 --- /dev/null +++ b/ktor-utils/common/test/io/ktor/util/SuspendFunctionGunTest.kt @@ -0,0 +1,40 @@ +/* + * Copyright 2014-2024 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package io.ktor.util + +import io.ktor.test.dispatcher.* +import io.ktor.util.pipeline.* +import kotlin.coroutines.* +import kotlin.test.* + +class SuspendFunctionGunTest { + @Test + fun throwsErrorWhenNoDistinctContinuation() = testSuspend { + val gun = SuspendFunctionGun(Unit, Unit, listOf({ _, _, _ -> })) + gun.addContinuation(gun.continuation) + + val cause = assertFailsWith { gun.continuation.context } + assertEquals("Not started", cause.message) + } + + @Test + fun returnsLastDistinctContinuationContext() = testSuspend { + val gun = SuspendFunctionGun(Unit, Unit, listOf({ _, _, _ -> }, { _, _, _ -> })) + val continuation = Continuation(EmptyCoroutineContext) {} + gun.addContinuation(continuation) + gun.addContinuation(gun.continuation) + + assertEquals(gun.continuation.context, continuation.context) + } + + @Test + fun returnsFirstDistinctContinuationContext() = testSuspend { + val gun = SuspendFunctionGun(Unit, Unit, listOf({ _, _, _ -> })) + val continuation = Continuation(EmptyCoroutineContext) {} + gun.addContinuation(continuation) + + assertEquals(gun.continuation.context, continuation.context) + } +}