Skip to content

Commit

Permalink
stdlib: DeepRecursiveFunction KT-31741
Browse files Browse the repository at this point in the history
Introduces coroutine-based framework to execute deeply recursive
functions that utilities the heap and thus avoid StackOverflowError.

Fixes KT-31741
  • Loading branch information
elizarov committed Apr 28, 2020
1 parent 151890d commit 0a97c6c
Show file tree
Hide file tree
Showing 6 changed files with 389 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ public actual inline fun <R, T> (suspend R.() -> T).startCoroutineUninterceptedO
completion: Continuation<T>
): Any? = this.asDynamic()(receiver, completion, false)

@InlineOnly
internal actual inline fun <R, P, T> (suspend R.(P) -> T).startCoroutineUninterceptedOrReturn(
receiver: R,
param: P,
completion: Continuation<T>
): Any? = this.asDynamic()(receiver, param, completion, false)

/**
* Creates unintercepted coroutine without receiver and with result type [T].
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ public actual inline fun <R, T> (suspend R.() -> T).startCoroutineUninterceptedO
completion: Continuation<T>
): Any? = (this as Function2<R, Continuation<T>, Any?>).invoke(receiver, completion)

@InlineOnly
internal actual inline fun <R, P, T> (suspend R.(P) -> T).startCoroutineUninterceptedOrReturn(
receiver: R,
param: P,
completion: Continuation<T>
): Any? = (this as Function3<R, P, Continuation<T>, Any?>).invoke(receiver, param, completion)

// JVM declarations

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,13 @@ public expect inline fun <R, T> (suspend R.() -> T).startCoroutineUninterceptedO
completion: Continuation<T>
): Any?

// Internal version that support arity-2 suspending functions, might be made public in the future if needed
internal expect inline fun <R, P, T> (suspend R.(P) -> T).startCoroutineUninterceptedOrReturn(
receiver: R,
param: P,
completion: Continuation<T>
): Any?

@SinceKotlin("1.3")
public expect fun <T> (suspend () -> T).createCoroutineUnintercepted(
completion: Continuation<T>
Expand Down
213 changes: 213 additions & 0 deletions libraries/stdlib/src/kotlin/util/DeepRecursive.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
/*
* Copyright 2010-2020 JetBrains s.r.o. and Kotlin Programming Language contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/

package kotlin

import kotlin.coroutines.*
import kotlin.coroutines.intrinsics.*

/**
* Defines deep recursive function that keeps its stack on the heap,
* which allows very deep recursive computations that do not use the actual call stack.
* To initiate a call to this deep recursive function use its [invoke] function.
* As a rule of thumb, it should be used if recursion goes deeper than a thousand calls.
*
* The [DeepRecursiveFunction] takes one parameter of type [T] and returns a result of type [R].
* The [block] of code defines the body of a recursive function. In this block
* [callRecursive][DeepRecursiveScope.callRecursive] function can be used to make a recursive call
* to the declared function. Other instances of [DeepRecursiveFunction] can be called
* in this scope with `callRecursive` extension, too.
*
* For example, take a look at the following recursive tree class and a deeply
* recursive instance of this tree with 100K nodes:
*
* ```
* class Tree(val left: Tree? = null, val right: Tree? = null)
* val deepTree = generateSequence(Tree()) { Tree(it) }.take(100_000).last()
* ```
*
* A regular recursive function can be defined to compute a depth of a tree:
*
* ```
* fun depth(t: Tree?): Int =
* if (t == null) 0 else max(depth(t.left), depth(t.right)) + 1
* println(depth(deepTree)) // StackOverflowError
* ```
*
* If this `depth` function is called for a `deepTree` it produces [StackOverflowError] because of deep recursion.
* However, the `depth` function can be rewritten using `DeepRecursiveFunction` in the following way, and then
* it successfully computes [`depth(deepTree)`][DeepRecursiveFunction.invoke] expression:
*
* ```
* val depth = DeepRecursiveFunction<Tree?, Int> { t ->
* if (t == null) 0 else max(callRecursive(t.left), callRecursive(t.right)) + 1
* }
* println(depth(deepTree)) // Ok
* ```
*
* Deep recursive functions can also mutually call each other using a heap for the stack via
* [callRecursive][DeepRecursiveScope.callRecursive] extension. For example, the
* following pair of mutually recursive functions computes the number of tree nodes at even depth in the tree.
*
* ```
* val mutualRecursion = object {
* val even: DeepRecursiveFunction<Tree?, Int> = DeepRecursiveFunction { t ->
* if (t == null) 0 else odd.callRecursive(t.left) + odd.callRecursive(t.right) + 1
* }
* val odd: DeepRecursiveFunction<Tree?, Int> = DeepRecursiveFunction { t ->
* if (t == null) 0 else even.callRecursive(t.left) + even.callRecursive(t.right)
* }
* }
* ```
*
* @param [T] the function parameter type.
* @param [R] the function result type.
* @param block the function body.
*/
@SinceKotlin("1.4")
public class DeepRecursiveFunction<T, R>(
internal val block: suspend DeepRecursiveScope<T, R>.(T) -> R
)

/**
* Initiates a call to this deep recursive function, forming a root of the call tree.
*
* This operator should not be used from inside of [DeepRecursiveScope] as it uses the call stack slot for
* initial recursive invocation. From inside of [DeepRecursiveScope] use
* [callRecursive][DeepRecursiveScope.callRecursive].
*/
@SinceKotlin("1.4")
public operator fun <T, R> DeepRecursiveFunction<T, R>.invoke(value: T): R =
DeepRecursiveScopeImpl<T, R>(block, value).runCallLoop()

/**
* A scope class for [DeepRecursiveFunction] function declaration that defines [callRecursive] methods to
* recursively call this function or another [DeepRecursiveFunction] putting the call activation frame on the heap.
*
* @param [T] function parameter type.
* @param [R] function result type.
*/
@RestrictsSuspension
@SinceKotlin("1.4")
public sealed class DeepRecursiveScope<T, R> {
/**
* Makes recursive call to this [DeepRecursiveFunction] function putting the call activation frame on the heap,
* as opposed to the actual call stack that is used by a regular recursive call.
*/
public abstract suspend fun callRecursive(value: T): R

/**
* Makes call to the specified [DeepRecursiveFunction] function putting the call activation frame on the heap,
* as opposed to the actual call stack that is used by a regular call.
*/
public abstract suspend fun <U, S> DeepRecursiveFunction<U, S>.callRecursive(value: U): S

@Deprecated(
level = DeprecationLevel.ERROR,
message =
"'invoke' should not be called from DeepRecursiveScope. " +
"Use 'callRecursive' to do recursion in the heap instead of the call stack.",
replaceWith = ReplaceWith("this.callRecursive(value)")
)
@Suppress("UNUSED_PARAMETER")
public operator fun DeepRecursiveFunction<*, *>.invoke(value: Any?): Nothing =
throw UnsupportedOperationException("Should not be called from DeepRecursiveScope")
}

// ================== Implementation ==================

private typealias DeepRecursiveFunctionBlock = suspend DeepRecursiveScope<*, *>.(Any?) -> Any?

private val UNDEFINED_RESULT = Result.success(COROUTINE_SUSPENDED)

@Suppress("UNCHECKED_CAST")
private class DeepRecursiveScopeImpl<T, R>(
block: suspend DeepRecursiveScope<T, R>.(T) -> R,
value: T
) : DeepRecursiveScope<T, R>(), Continuation<R> {
// Active function block
private var function: DeepRecursiveFunctionBlock = block as DeepRecursiveFunctionBlock

// Value to call function with
private var value: Any? = value

// Continuation of the current call
private var cont: Continuation<Any?>? = this as Continuation<Any?>

// Completion result (completion of the whole call stack)
private var result: Result<Any?> = UNDEFINED_RESULT

override val context: CoroutineContext
get() = EmptyCoroutineContext

override fun resumeWith(result: Result<R>) {
this.cont = null
this.result = result
}

override suspend fun callRecursive(value: T): R = suspendCoroutineUninterceptedOrReturn { cont ->
// calling the same function that is currently active
this.cont = cont as Continuation<Any?>
this.value = value
COROUTINE_SUSPENDED
}

override suspend fun <U, S> DeepRecursiveFunction<U, S>.callRecursive(value: U): S = suspendCoroutineUninterceptedOrReturn { cont ->
// calling another recursive function
val function = block as DeepRecursiveFunctionBlock
with(this@DeepRecursiveScopeImpl) {
val currentFunction = this.function
if (function !== currentFunction) {
// calling a different function -- create a trampoline to restore function ref
this.function = function
this.cont = crossFunctionCompletion(currentFunction, cont as Continuation<Any?>)
} else {
// calling the same function -- direct
this.cont = cont as Continuation<Any?>
}
this.value = value
}
COROUTINE_SUSPENDED
}

private fun crossFunctionCompletion(
currentFunction: DeepRecursiveFunctionBlock,
cont: Continuation<Any?>
): Continuation<Any?> = Continuation(EmptyCoroutineContext) {
this.function = currentFunction
// When going back from a trampoline we cannot just call cont.resume (stack usage!)
// We delegate the cont.resumeWith(it) call to runCallLoop
this.cont = cont
this.result = it
}

@Suppress("UNCHECKED_CAST")
fun runCallLoop(): R {
while (true) {
// Note: cont is set to null in DeepRecursiveScopeImpl.resumeWith when the whole computation completes
val result = this.result
val cont = this.cont
?: return (result as Result<R>).getOrThrow() // done -- final result
// The order of comparison is important here for that case of rogue class with broken equals
if (UNDEFINED_RESULT == result) {
// call "function" with "value" using "cont" as completion
val r = try {
// This is block.startCoroutine(this, value, cont)
function.startCoroutineUninterceptedOrReturn(this, value, cont)
} catch (e: Throwable) {
cont.resumeWithException(e)
continue
}
// If the function returns without suspension -- calls its continuation immediately
if (r !== COROUTINE_SUSPENDED)
cont.resume(r as R)
} else {
// we returned from a crossFunctionCompletion trampoline -- call resume here
this.result = UNDEFINED_RESULT // reset result back
cont.resumeWith(result)
}
}
}
}
143 changes: 143 additions & 0 deletions libraries/stdlib/test/utils/DeepRecursiveTest.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
/*
* Copyright 2010-2020 JetBrains s.r.o. and Kotlin Programming Language contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/

package test.utils

import kotlin.test.*

class DeepRecursiveTest {
@Test
fun testSimpleReturn() {
// just returns a value without any recursive calls
val ok = DeepRecursiveFunction<Int, String> { i -> "Ok$i" }
assertEquals("Ok42", ok(42))
}

@Test
fun testDeepTreeDepth() {
val n = 100_000
assertEquals(n, depth(deepTree(n)))
}

@Test
fun testBinaryTreeDepth() {
val k = 15
assertEquals(k, depth(binaryTree(k)))
}

private class MutualRec {
val even: DeepRecursiveFunction<Tree?, Int> = DeepRecursiveFunction { t ->
if (t == null) 0 else odd.callRecursive(t.left) + odd.callRecursive(t.right) + 1
}

val odd: DeepRecursiveFunction<Tree?, Int> = DeepRecursiveFunction { t ->
if (t == null) 0 else even.callRecursive(t.left) + even.callRecursive(t.right)
}
}

@Test
fun testDeepTreeOddEvenNodesMutual() {
val n = 100_000
val dt = deepTree(n)
val rec = MutualRec()
assertEquals(n / 2, rec.even(dt))
assertEquals(n / 2, rec.odd(dt))
}

@Test
fun testBinaryTreeOddEvenNodesMutual() {
val k = 15
val bt = binaryTree(k)
val rec = MutualRec()
assertEquals(21845, rec.even(bt))
assertEquals(10922, rec.odd(bt))
}

private class MutualAndDirectMixRec {
val b: DeepRecursiveFunction<Int, String> = DeepRecursiveFunction { i -> "b$i" }

val a: DeepRecursiveFunction<Int, String> = DeepRecursiveFunction { i ->
when (i) {
// mix callRecursive calls to other function and in this context
0 -> b.callRecursive(1) + callRecursive(2) + aa().callRecursive(3)
else -> "a$i"
}
}

fun aa() = a
}

@Test
fun testMutualAndDirectMix() {
// mix of callRecursion on this scope and on other DRF
val rec = MutualAndDirectMixRec()
val s = rec.a.invoke(0)
assertEquals("b1a2a3", s)
}

private class EqualToAnythingClassRec {
var nullCount = 0

val a: DeepRecursiveFunction<Tree?, EqualToAnything> = DeepRecursiveFunction { t ->
if (t == null) EqualToAnything(nullCount++) else b.callRecursive(t.left)
}

val b: DeepRecursiveFunction<Tree?, EqualToAnything> = DeepRecursiveFunction { t ->
if (t == null) EqualToAnything(nullCount++) else a.callRecursive(t.left)
}
}

@Test
fun testEqualToAnythingClass() {
// Mutually recursive tail calls & broken equals
val rec = EqualToAnythingClassRec()
val result = rec.a.invoke(deepTree(100))
assertEquals(1, rec.nullCount)
assertEquals(0, result.i)
}

@Test
fun testBadClass() {
val compute = object {
val a: DeepRecursiveFunction<Bad, Bad> = DeepRecursiveFunction { v -> Bad(v.i + 1) }
val b: DeepRecursiveFunction<Bad, Bad> = DeepRecursiveFunction { v ->
when (v.i) {
0 -> callRecursive(Bad(1))
1 -> Bad(a.callRecursive(Bad(19)).i + callRecursive(Bad(2)).i)
2 -> Bad(a.callRecursive(Bad(20)).i + 1)
else -> error("Cannot happen")
}
}
}
assertEquals(42, compute.b(Bad(0)).i)
}

private class Tree(val left: Tree? = null, val right: Tree? = null)

private fun deepTree(n: Int) = generateSequence(Tree()) { Tree(it) }.take(n).last()

private fun binaryTree(k: Int): Tree? =
if (k == 0) null else Tree(binaryTree(k - 1), binaryTree(k - 1))

private val depth = DeepRecursiveFunction<Tree?, Int> { t ->
if (t == null) 0 else maxOf(
callRecursive(t.left),
callRecursive(t.right)
) + 1
}

// It is equals to any other class
private class EqualToAnything(val i: Int) {
override fun equals(other: Any?): Boolean = true
override fun toString(): String = "OK"
}

// Throws exception on all object methods
private class Bad(val i: Int) {
override fun equals(other: Any?): Boolean = error("BAD")
override fun hashCode(): Int = error("BAD")
override fun toString(): String = error("BAD")
}
}
Loading

0 comments on commit 0a97c6c

Please sign in to comment.