Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some clean up for Free (and friends). #1085

Merged
merged 7 commits into from
Jun 11, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 127 additions & 59 deletions free/src/main/scala/cats/free/Free.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,62 +6,15 @@ import scala.annotation.tailrec
import cats.data.Xor, Xor.{Left, Right}
import cats.arrow.FunctionK

object Free {
/**
* Return from the computation with the given value.
*/
private final case class Pure[S[_], A](a: A) extends Free[S, A]

/** Suspend the computation with the given suspension. */
private final case class Suspend[S[_], A](a: S[A]) extends Free[S, A]

/** Call a subroutine and continue with the given function. */
private final case class Gosub[S[_], B, C](c: Free[S, C], f: C => Free[S, B]) extends Free[S, B]

/**
* Suspend a value within a functor lifting it to a Free.
*/
def liftF[F[_], A](value: F[A]): Free[F, A] = Suspend(value)

/** Suspend the Free with the Applicative */
def suspend[F[_], A](value: => Free[F, A])(implicit F: Applicative[F]): Free[F, A] =
liftF(F.pure(())).flatMap(_ => value)

/** Lift a pure value into Free */
def pure[S[_], A](a: A): Free[S, A] = Pure(a)

final class FreeInjectPartiallyApplied[F[_], G[_]] private[free] {
def apply[A](fa: F[A])(implicit I : Inject[F, G]): Free[G, A] =
Free.liftF(I.inj(fa))
}

def inject[F[_], G[_]]: FreeInjectPartiallyApplied[F, G] = new FreeInjectPartiallyApplied

/**
* `Free[S, ?]` has a monad for any type constructor `S[_]`.
*/
implicit def freeMonad[S[_]]: MonadRec[Free[S, ?]] =
new MonadRec[Free[S, ?]] {
def pure[A](a: A): Free[S, A] = Free.pure(a)
override def map[A, B](fa: Free[S, A])(f: A => B): Free[S, B] = fa.map(f)
def flatMap[A, B](a: Free[S, A])(f: A => Free[S, B]): Free[S, B] = a.flatMap(f)
def tailRecM[A, B](a: A)(f: A => Free[S, A Xor B]): Free[S, B] =
f(a).flatMap(_ match {
case Xor.Left(a1) => tailRecM(a1)(f) // recursion OK here, since Free is lazy
case Xor.Right(b) => pure(b)
})
}
}

import Free._

/**
* A free operational monad for some functor `S`. Binding is done
* using the heap instead of the stack, allowing tail-call
* elimination.
*/
sealed abstract class Free[S[_], A] extends Product with Serializable {

import Free.{ Pure, Suspend, Gosub }

final def map[B](f: A => B): Free[S, B] =
flatMap(a => Pure(f(a)))

Expand Down Expand Up @@ -115,7 +68,12 @@ sealed abstract class Free[S[_], A] extends Product with Serializable {
loop(this)
}

final def run(implicit S: Comonad[S]): A = go(S.extract)
/**
* Run to completion, using the given comonad to extract the
* resumption.
*/
final def run(implicit S: Comonad[S]): A =
go(S.extract)

/**
* Run to completion, using a function that maps the resumption
Expand All @@ -129,32 +87,142 @@ sealed abstract class Free[S[_], A] extends Product with Serializable {
runM2(this)
}

/**
* Run to completion, using monadic recursion to evaluate the
* resumption in the context of `S`.
*/
final def runTailRec(implicit S: MonadRec[S]): S[A] = {
def step(rma: Free[S, A]): S[Xor[Free[S, A], A]] =
rma match {
case Pure(a) =>
S.pure(Xor.right(a))
case Suspend(ma) =>
S.map(ma)(Xor.right(_))
case Gosub(curr, f) =>
curr match {
case Pure(x) =>
S.pure(Xor.left(f(x)))
case Suspend(mx) =>
S.map(mx)(x => Xor.left(f(x)))
case Gosub(prev, g) =>
S.pure(Xor.left(prev.flatMap(w => g(w).flatMap(f))))
}
}
S.tailRecM(this)(step)
}

/**
* Catamorphism for `Free`.
*
* Run to completion, mapping the suspension with the given transformation at each step and
* accumulating into the monad `M`.
* Run to completion, mapping the suspension with the given
* transformation at each step and accumulating into the monad `M`.
*
* This method uses `MonadRec[M]` to provide stack-safety.
*/
final def foldMap[M[_]](f: FunctionK[S,M])(implicit M: MonadRec[M]): M[A] =
final def foldMap[M[_]](f: FunctionK[S, M])(implicit M: MonadRec[M]): M[A] =
M.tailRecM(this)(_.step match {
case Pure(a) => M.pure(Xor.right(a))
case Suspend(sa) => M.map(f(sa))(Xor.right)
case Gosub(c, g) => M.map(c.foldMap(f))(cc => Xor.left(g(cc)))
})

/**
* Compile your Free into another language by changing the suspension functor
* using the given natural transformation.
* Be careful if your natural transformation is effectful, effects are applied by mapSuspension.
* Compile your free monad into another language by changing the
* suspension functor using the given natural transformation `f`.
*
* If your natural transformation is effectful, be careful. These
* effects will be applied by `compile`.
*/
final def mapSuspension[T[_]](f: FunctionK[S,T]): Free[T, A] =
final def compile[T[_]](f: FunctionK[S, T]): Free[T, A] =
foldMap[Free[T, ?]] {
new FunctionK[S, Free[T, ?]] {
def apply[B](fa: S[B]): Free[T, B] = Suspend(f(fa))
}
}(Free.freeMonad)

final def compile[T[_]](f: FunctionK[S,T]): Free[T, A] = mapSuspension(f)
override def toString(): String =
"Free(...)"
}

object Free {

/**
* Return from the computation with the given value.
*/
private[free] final case class Pure[S[_], A](a: A) extends Free[S, A]

/** Suspend the computation with the given suspension. */
private[free] final case class Suspend[S[_], A](a: S[A]) extends Free[S, A]

/** Call a subroutine and continue with the given function. */
private[free] final case class Gosub[S[_], B, C](c: Free[S, C], f: C => Free[S, B]) extends Free[S, B]

/**
* Lift a pure `A` value into the free monad.
*/
def pure[S[_], A](a: A): Free[S, A] = Pure(a)

/**
* Lift an `F[A]` value into the free monad.
*/
def liftF[F[_], A](value: F[A]): Free[F, A] = Suspend(value)

/**
* Suspend the creation of a `Free[F, A]` value.
*/
def suspend[F[_], A](value: => Free[F, A]): Free[F, A] =
pure(()).flatMap(_ => value)

override def toString(): String = "Free(...)"
/**
* This method is used to defer the application of an Inject[F, G]
* instance. The actual work happens in
* `FreeInjectPartiallyApplied#apply`.
*
* This method exists to allow the `F` and `G` parameters to be
* bound independently of the `A` parameter below.
*/
def inject[F[_], G[_]]: FreeInjectPartiallyApplied[F, G] =
new FreeInjectPartiallyApplied

/**
* Pre-application of an injection to a `F[A]` value.
*/
final class FreeInjectPartiallyApplied[F[_], G[_]] private[free] {
def apply[A](fa: F[A])(implicit I: Inject[F, G]): Free[G, A] =
Free.liftF(I.inj(fa))
}

/**
* `Free[S, ?]` has a monad for any type constructor `S[_]`.
*/
implicit def freeMonad[S[_]]: MonadRec[Free[S, ?]] =
new MonadRec[Free[S, ?]] {
def pure[A](a: A): Free[S, A] = Free.pure(a)
override def map[A, B](fa: Free[S, A])(f: A => B): Free[S, B] = fa.map(f)
def flatMap[A, B](a: Free[S, A])(f: A => Free[S, B]): Free[S, B] = a.flatMap(f)
def tailRecM[A, B](a: A)(f: A => Free[S, A Xor B]): Free[S, B] =
f(a).flatMap(_ match {
case Left(a1) => tailRecM(a1)(f) // recursion OK here, since Free is lazy
case Right(b) => pure(b)
})
}

/**
* Perform a stack-safe monadic fold from the source context `F`
* into the target monad `G`.
*
* This method can express short-circuiting semantics. Even when
* `fa` is an infinite structure, this method can potentially
* terminate if the `foldRight` implementation for `F` and the
* `tailRecM` implementation for `G` are sufficiently lazy.
*/
def foldLeftM[F[_]: Foldable, G[_]: MonadRec, A, B](fa: F[A], z: B)(f: (B, A) => G[B]): G[B] =
unsafeFoldLeftM[F, Free[G, ?], A, B](fa, z) { (b, a) =>
Free.liftF(f(b, a))
}.runTailRec
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did I miss it, or does this method not use the G.tailRecM method at all?

Also, it seems like there is double trampolining (once via Free and once via Eval in F.foldRight)? I wonder if a more direct implementation (using either Free or Eval, but not both) is possible.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@TomasMikula as the author of this code, I'll be the first to say that it is a bit bizarre. G.tailRecM should be being called by the runTailRec call on the Free structure. Eval makes an appearance due to the foldRight type signature, but really we are just using foldRight because of its laziness, and unsafeFoldLeftM is calling .value on the Eval, so we don't actually get any Eval trampolining here. It's effectively like we are using a scalaz-style foldRight that takes a byname parameter, since we aren't taking advantage Eval's trampolining.

If you can figure out a way to provide a more direct implementation, I'd definitely be interested. I think that this has to be done in terms of foldRight if we want laziness, though, which means Eval will definitely be in play. You may be able to use MonadRec and foldRight without Free, and if so that'd be great - as it stands I'm kind of convinced that we should move Free (the monad but not the other free structures) back into core for this sort of thing.

Copy link
Contributor

@ceedubs ceedubs Jun 11, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One other note: it really isn't obvious to me that this implementation should be stack safe, but the tests kind of prove that it is ¯\_(ツ)_/¯

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

somehow I completely missed that .runTailRec at the end 😊


private def unsafeFoldLeftM[F[_], G[_], A, B](fa: F[A], z: B)(f: (B, A) => G[B])(implicit F: Foldable[F], G: Monad[G]): G[B] =
F.foldRight(fa, Always((w: B) => G.pure(w))) { (a, lb) =>
Always((w: B) => G.flatMap(f(w, a))(lb.value))
}.value.apply(z)
}
2 changes: 0 additions & 2 deletions free/src/main/scala/cats/free/Trampoline.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package cats
package free

import cats.std.function.catsStdBimonadForFunction0

// To workaround SI-7139 `object Trampoline` needs to be defined inside the package object
// together with the type alias.
private[free] abstract class TrampolineFunctions {
Expand Down
35 changes: 31 additions & 4 deletions free/src/test/scala/cats/free/FreeTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ class FreeTests extends CatsSuite {
rr.toString.length should be > 0
}

test("mapSuspension id"){
test("compile id"){
forAll { x: Free[List, Int] =>
x.mapSuspension(FunctionK.id[List]) should === (x)
x.compile(FunctionK.id[List]) should === (x)
}
}

Expand All @@ -42,9 +42,9 @@ class FreeTests extends CatsSuite {
val _ = Free.suspend(yikes[Option, Int])
}

test("mapSuspension consistent with foldMap"){
test("compile consistent with foldMap"){
forAll { x: Free[List, Int] =>
val mapped = x.mapSuspension(headOptionU)
val mapped = x.compile(headOptionU)
val folded = mapped.foldMap(FunctionK.id[Option])
folded should === (x.foldMap(headOptionU))
}
Expand Down Expand Up @@ -78,6 +78,33 @@ class FreeTests extends CatsSuite {

assert(10000 == a(0).foldMap(runner))
}

test(".runTailRec") {
val r = Free.pure[List, Int](12358)
def recurse(r: Free[List, Int], n: Int): Free[List, Int] =
if (n > 0) recurse(r.flatMap(x => Free.pure(x + 1)), n - 1) else r
val res = recurse(r, 100000).runTailRec
assert(res == List(112358))
}

test(".foldLeftM") {
// you can see .foldLeftM traversing the entire structure by
// changing the constant argument to .take and observing the time
// this test takes.
val ns = Stream.from(1).take(1000)
val res = Free.foldLeftM[Stream, Xor[Int, ?], Int, Int](ns, 0) { (sum, n) =>
if (sum >= 2) Xor.left(sum) else Xor.right(sum + n)
}
assert(res == Xor.left(3))
}

test(".foldLeftM short-circuiting") {
val ns = Stream.continually(1)
val res = Free.foldLeftM[Stream, Xor[Int, ?], Int, Int](ns, 0) { (sum, n) =>
if (sum >= 100000) Xor.left(sum) else Xor.right(sum + n)
}
assert(res == Xor.left(100000))
}
}

object FreeTests extends FreeTestsInstances {
Expand Down