diff --git a/core/src/main/scala/cats/Eval.scala b/core/src/main/scala/cats/Eval.scala index dfe03f7d1e..6fe89286d9 100644 --- a/core/src/main/scala/cats/Eval.scala +++ b/core/src/main/scala/cats/Eval.scala @@ -1,5 +1,7 @@ package cats +import data.Xor +import scala.reflect.ClassTag import scala.annotation.tailrec import cats.syntax.all._ @@ -104,6 +106,20 @@ sealed abstract class Eval[A] extends Serializable { self => * Later[A] with an equivalent computation will be returned. */ def memoize: Eval[A] + + /** + * Returns a new Eval which will catch any non-fatal exception + * thrown by the running of the computation. + */ + def catchNonFatal: Eval[Throwable Xor A] = + Eval.always(Xor.catchNonFatal(value)) + + /** + * Returns a new Eval catch some subset of the exceptions that might + * be thrown by the running of the computation. + */ + def catchOnly[T >: Null <: Throwable: ClassTag]: Eval[T Xor A] = + Eval.always(Xor.catchOnly[T](value)) } diff --git a/jvm/src/main/scala/cats/jvm/EvalAsync.scala b/jvm/src/main/scala/cats/jvm/EvalAsync.scala new file mode 100644 index 0000000000..2814c25545 --- /dev/null +++ b/jvm/src/main/scala/cats/jvm/EvalAsync.scala @@ -0,0 +1,41 @@ +package cats +package jvm + +import data.Reader +import java.util.concurrent.{Callable, CountDownLatch, ExecutorService} + +object EvalAsync { + def apply[A](cb: (A => Unit) => Unit): Eval[A] = Eval.always { + val cdl = new CountDownLatch(1) + var result: Option[A] = None + cb((a: A) => {result = Some(a); cdl.countDown}) + cdl.await + result.get // YOLO + } + + implicit class EvalFork[A](val eval: Eval[A]) extends AnyVal { + def callable(cb: A => Unit): Callable[Unit] = new Callable[Unit] { + def call = cb(eval.value) + } + + /** + * Returns an Eval that will produce the same value as the wrapped + * Eval, but extracting the value from the resulting Eval will + * submit the work to the given ExecutorService + */ + def fork: Reader[ExecutorService, Eval[A]] = + Reader { pool => + EvalAsync { cb => + val _ = pool.submit(callable(cb)) + } + } + + /** + * Run this computation asynchronously, and call the callback with the result + */ + final def asyncValue(cb: A => Unit): Reader[ExecutorService, Unit] = + Reader { pool => + val _ = pool.submit(callable(cb)) + } + } +} diff --git a/jvm/src/test/scala/cats/tests/EvalAsyncTests.scala b/jvm/src/test/scala/cats/tests/EvalAsyncTests.scala new file mode 100644 index 0000000000..60f2e59a52 --- /dev/null +++ b/jvm/src/test/scala/cats/tests/EvalAsyncTests.scala @@ -0,0 +1,66 @@ +package cats +package jvm +package tests + +import cats.tests.CatsSuite + +class EvalAsyncTests extends CatsSuite { + test("evalAsync should be stack-safe") { + import data.Streaming + val ones = List(Eval.now(1), + Eval.later(1), + Eval.always(1), + EvalAsync[Int](_(1))) + + val onesStream: Streaming[Eval[Int]] = Streaming.continually(Streaming.fromList(ones)).flatMap(x => x) + + + def taskMap2[A,B,C](t1: Eval[A], t2: Eval[B])(f: (A,B) => C): Eval[C] = { + t1.flatMap(a => t2.map(b => f(a,b))) + } + + def sequenceStreaming[A](fa: Streaming[Eval[A]]): Eval[Streaming[A]] = { + fa.foldRight(Eval.later(Eval.now(Streaming.empty[A])))((a, st) => + st.map(b => taskMap2(b,a)((x,y) => Streaming.cons(y,x)))).value + } + + val howmany = 1000000 + + sequenceStreaming(onesStream.take(howmany)).value.foldLeft(0)((x, _) => x + 1) should be (howmany) + + onesStream.take(howmany).sequence.value.foldLeft(0)((x, _) => x + 1) should be (howmany) + } + + test("EvalAsync should run forked tasks on another thread") { + import EvalAsync._ + + val pool = new java.util.concurrent.ForkJoinPool + + var time1: Long = 0 + var time2: Long = 0 + + val t1: Eval[Unit] = Eval.later { + Thread.sleep(2000) + time1 = System.currentTimeMillis + }.fork.run(pool) + + val t2: Eval[Unit] = Eval.later { + Thread.sleep(1000) + time2 = System.currentTimeMillis + () + }.fork.run(pool) + + val cdl = new java.util.concurrent.CountDownLatch(2) + + t1.asyncValue(_ => cdl.countDown).run(pool) + t2.asyncValue(_ => cdl.countDown).run(pool) + + time1 should be(0L) + time2 should be(0L) + + cdl.await + + time2 should be > 0L + time1 should be > time2 + } +} diff --git a/tests/src/test/scala/cats/tests/EvalTests.scala b/tests/src/test/scala/cats/tests/EvalTests.scala index 46c490898d..4cdc0fda39 100644 --- a/tests/src/test/scala/cats/tests/EvalTests.scala +++ b/tests/src/test/scala/cats/tests/EvalTests.scala @@ -91,6 +91,20 @@ class EvalTests extends CatsSuite { } } + test("eval should be stack-safe") { + val ones = List(Eval.now(1), + Eval.later(1), + Eval.always(1)) + import data.Streaming + // an infinite stream of ones + val onesStream: Streaming[Eval[Int]] = Streaming.continually(Streaming.fromList(ones)).flatMap(x => x) + + val howmany = 1000000 + onesStream.take(howmany).sequence.value.foldLeft(0)((x, _) => x + 1) should be (howmany) + + } + + { implicit val iso = CartesianTests.Isomorphisms.invariant[Eval] checkAll("Eval[Int]", BimonadTests[Eval].bimonad[Int, Int, Int])