From d609d261857150005050e5e0e8b3a4ba255695ce Mon Sep 17 00:00:00 2001 From: Vasil Vasilev Date: Wed, 18 Aug 2021 23:44:02 +0200 Subject: [PATCH 1/5] Reimplement `parJoin` --- core/shared/src/main/scala/fs2/Stream.scala | 158 +++++++++++++------- 1 file changed, 100 insertions(+), 58 deletions(-) diff --git a/core/shared/src/main/scala/fs2/Stream.scala b/core/shared/src/main/scala/fs2/Stream.scala index b409fb7d2a..60bc437b44 100644 --- a/core/shared/src/main/scala/fs2/Stream.scala +++ b/core/shared/src/main/scala/fs2/Stream.scala @@ -3860,22 +3860,19 @@ object Stream extends StreamLowPriority { * * @param maxOpen Maximum number of open inner streams at any time. Must be > 0. */ - def parJoin( - maxOpen: Int - )(implicit F: Concurrent[F]): Stream[F, O] = { - assert(maxOpen > 0, "maxOpen must be > 0, was: " + maxOpen) + def parJoin(maxOpen: Int)(implicit F: Concurrent[F]): Stream[F, O] = { + assert(maxOpen > 0, s"maxOpen must be > 0, was: $maxOpen") + if (maxOpen === 1) outer.flatten else { val fstream: F[Stream[F, O]] = for { - done <- SignallingRef(None: Option[Option[Throwable]]) - available <- Semaphore[F](maxOpen.toLong) + done <- SignallingRef(none[Option[Throwable]]) + available <- Semaphore(maxOpen.toLong) // starts with 1 because outer stream is running by default - running <- SignallingRef(1L) - // sync queue assures we won't overload heap when resulting stream is not able to catchup with inner streams - outputChan <- Channel.synchronous[F, Chunk[O]] + running <- SignallingRef(1) + outcomes <- Channel.unbounded[F, F[Unit]] + output <- Channel.synchronous[F, Chunk[O]] } yield { - // stops the join evaluation - // all the streams will be terminated. If err is supplied, that will get attached to any error currently present def stop(rslt: Option[Throwable]): F[Unit] = done.update { case rslt0 @ Some(Some(err0)) => @@ -3883,69 +3880,114 @@ object Stream extends StreamLowPriority { Some(Some(CompositeFailure(err0, err))) } case _ => Some(rslt) - } >> outputChan.close.void - - def untilDone[A](str: Stream[F, A]) = str.interruptWhen(done.map(_.nonEmpty)) + } val incrementRunning: F[Unit] = running.update(_ + 1) val decrementRunning: F[Unit] = - running.modify { n => - val now = n - 1 - now -> (if (now == 0) stop(None) else F.unit) - }.flatten + running + .updateAndGet(_ - 1) + .flatMap(now => if (now == 0) outcomes.close.void else F.unit) - // "block" and await until the `running` counter drops to zero. val awaitWhileRunning: F[Unit] = running.discrete.forall(_ > 0).compile.drain - // signals that a running stream, either inner or ourer, finished with or without failure. - def endWithResult(result: Either[Throwable, Unit]): F[Unit] = - result match { - case Right(()) => decrementRunning - case Left(err) => stop(Some(err)) >> decrementRunning - } - - def sendToChannel(str: Stream[F, O]) = str.chunks.foreach(x => outputChan.send(x).void) - - // runs one inner stream, each stream is forked. - // terminates when killSignal is true - // failures will be propagated through `done` Signal - // note that supplied scope's resources must be leased before the inner stream forks the execution to another thread - // and that it must be released once the inner stream terminates or fails. def runInner(inner: Stream[F, O], outerScope: Scope[F]): F[Unit] = F.uncancelable { _ => - outerScope.lease.flatMap { lease => - available.acquire >> - incrementRunning >> + outerScope.lease + .flatTap(_ => available.acquire >> incrementRunning) + .flatMap { lease => F.start { - // Note that the `interrupt` must be AFTER the send to the sync channel, - // otherwise the process may hang to send last item while being interrupted - val backInsertions = untilDone(sendToChannel(inner)) - for { - r <- backInsertions.compile.drain.attempt - cancelResult <- lease.cancel - _ <- available.release - _ <- endWithResult(CompositeFailure.fromResults(r, cancelResult)) - } yield () + inner.chunks + .evalMap(s => output.send(s).void) + .interruptWhen(done.map(_.nonEmpty)) + .compile + .drain + .guaranteeCase { + case Outcome.Succeeded(fu) => + lease.cancel.flatMap { + case Left(t) => stop(Some(t)) + case Right(()) => outcomes.send(fu).void + } >> available.release >> decrementRunning + + case Outcome.Errored(t) => + lease.cancel.flatMap { cancelResult => + (CompositeFailure.fromResults(Left(t), cancelResult) match { + case Left(t) => stop(Some(t)) + case Right(()) => F.unit + }) + } >> available.release >> decrementRunning + + case Outcome.Canceled() => + lease.cancel.flatMap { + case Left(t) => stop(Some(t)) + case Right(()) => F.unit + } >> available.release >> decrementRunning + } + .attempt + .void }.void - } + } } - def runInnerScope(inner: Stream[F, O]): Stream[F, INothing] = - new Stream(Pull.getScope[F].flatMap((sc: Scope[F]) => Pull.eval(runInner(inner, sc)))) - - // runs the outer stream, interrupts when kill == true, and then decrements the `running` def runOuter: F[Unit] = - untilDone(outer.flatMap(runInnerScope)).compile.drain.attempt.flatMap(endWithResult) + F.uncancelable { _ => + outer + .flatMap(inner => + new Stream( + Pull.getScope[F].flatMap(outerScope => Pull.eval(runInner(inner, outerScope))) + ) + ) + .interruptWhen(done.map(_.nonEmpty)) + .compile + .drain + .guaranteeCase { + case Outcome.Succeeded(fu) => + outcomes.send(fu) >> decrementRunning + + case Outcome.Errored(t) => + stop(Some(t)) >> decrementRunning + + case Outcome.Canceled() => + decrementRunning + } + .attempt + .void + } - // awaits when all streams (outer + inner) finished, - // and then collects result of the stream (outer + inner) execution - def signalResult: F[Unit] = - done.get.flatMap(_.flatten.fold[F[Unit]](F.unit)(F.raiseError)) - val endOuter: F[Unit] = stop(None) >> awaitWhileRunning >> signalResult + def outcomeJoiner: F[Unit] = + outcomes.stream + .evalMap(identity) + .compile + .drain + .guaranteeCase { + case Outcome.Succeeded(_) => + stop(None) >> output.close.void - val backEnqueue = Stream.bracket(F.start(runOuter))(_ => endOuter) + case Outcome.Errored(t) => + stop(Some(t)) >> output.close.void + + case Outcome.Canceled() => + stop(None) >> output.close.void + } + .attempt + .void - backEnqueue >> outputChan.stream.unchunks + def signalResult(fiber: Fiber[F, Throwable, Unit]): F[Unit] = + done.get.flatMap { blah => + blah.flatten.fold[F[Unit]](fiber.joinWithNever)(F.raiseError) + } + + Stream + .bracket(F.start(runOuter) >> F.start(outcomeJoiner)) { fiber => + stop(None) >> + // in case of short-circuiting, the `fiberJoiner` would not have had a chance + // to wait until all fibers have been joined, so we need to do it manually + // by waiting on the counter + awaitWhileRunning >> + signalResult(fiber) + } + .flatMap { _ => + output.stream.flatMap(Stream.chunk(_).covary[F]) + } } Stream.eval(fstream).flatten From 2424255e8b1f0835b627d50ae9a305843457bc06 Mon Sep 17 00:00:00 2001 From: Vasil Vasilev Date: Wed, 18 Aug 2021 23:44:14 +0200 Subject: [PATCH 2/5] Add short circuiting monad transformer tests --- .../test/scala/fs2/StreamParJoinSuite.scala | 123 ++++++++++++++++++ 1 file changed, 123 insertions(+) diff --git a/core/shared/src/test/scala/fs2/StreamParJoinSuite.scala b/core/shared/src/test/scala/fs2/StreamParJoinSuite.scala index 9f64d7e9ff..0719c687db 100644 --- a/core/shared/src/test/scala/fs2/StreamParJoinSuite.scala +++ b/core/shared/src/test/scala/fs2/StreamParJoinSuite.scala @@ -23,11 +23,14 @@ package fs2 import scala.concurrent.duration._ +import cats.data.{EitherT, OptionT} import cats.effect.IO import cats.effect.kernel.{Deferred, Ref} import cats.syntax.all._ import org.scalacheck.effect.PropF.forAllF +import scala.util.control.NoStackTrace + class StreamParJoinSuite extends Fs2Suite { test("no concurrency") { forAllF { (s: Stream[Pure, Int]) => @@ -204,4 +207,124 @@ class StreamParJoinSuite extends Fs2Suite { .parJoinUnbounded ++ Stream.emit(1)).compile.drain .intercept[Err] } + + group("short-circuiting transformers") { + test("do not block while evaluating a stream of streams in IO in parallel") { + def f(n: Int): Stream[IO, String] = Stream(n).map(_.toString) + + Stream(1, 2, 3) + .map(f) + .parJoinUnbounded + .compile + .toList + .map(_.toSet) + .flatMap { actual => + IO(assertEquals(actual, Set("1", "2", "3"))) + } + } + + test( + "do not block while evaluating a stream of streams in EitherT[IO, Throwable, *] in parallel - right" + ) { + def f(n: Int): Stream[EitherT[IO, Throwable, *], String] = Stream(n).map(_.toString) + + Stream(1, 2, 3) + .map(f) + .parJoinUnbounded + .compile + .toList + .map(_.toSet) + .value + .flatMap { actual => + IO(assertEquals(actual, Right(Set("1", "2", "3")))) + } + } + + test( + "do not block while evaluating a stream of streams in EitherT[IO, Throwable, *] in parallel - left" + ) { + case object TestException extends Throwable with NoStackTrace + + def f(n: Int): Stream[EitherT[IO, Throwable, *], String] = + if (n % 2 != 0) Stream(n).map(_.toString) + else Stream.eval[EitherT[IO, Throwable, *], String](EitherT.leftT(TestException)) + + Stream(1, 2, 3) + .map(f) + .parJoinUnbounded + .compile + .toList + .value + .flatMap { actual => + IO(assertEquals(actual, Left(TestException))) + } + } + + test("do not block while evaluating an EitherT.left outer stream") { + case object TestException extends Throwable with NoStackTrace + + def f(n: Int): Stream[EitherT[IO, Throwable, *], String] = Stream(n).map(_.toString) + + Stream + .eval[EitherT[IO, Throwable, *], Int](EitherT.leftT[IO, Int](TestException)) + .map(f) + .parJoinUnbounded + .compile + .toList + .value + .flatMap { actual => + IO(assertEquals(actual, Left(TestException))) + } + } + + test( + "do not block while evaluating a stream of streams in OptionT[IO, *] in parallel - some" + ) { + def f(n: Int): Stream[OptionT[IO, *], String] = Stream(n).map(_.toString) + + Stream(1, 2, 3) + .map(f) + .parJoinUnbounded + .compile + .toList + .map(_.toSet) + .value + .flatMap { actual => + IO(assertEquals(actual, Some(Set("1", "2", "3")))) + } + } + + test( + "do not block while evaluating a stream of streams in OptionT[IO, *] in parallel - none" + ) { + def f(n: Int): Stream[OptionT[IO, *], String] = + if (n % 2 != 0) Stream(n).map(_.toString) + else Stream.eval[OptionT[IO, *], String](OptionT.none) + + Stream(1, 2, 3) + .map(f) + .parJoinUnbounded + .compile + .toList + .value + .flatMap { actual => + IO(assertEquals(actual, None)) + } + } + + test("do not block while evaluating an OptionT.none outer stream") { + def f(n: Int): Stream[OptionT[IO, *], String] = Stream(n).map(_.toString) + + Stream + .eval[OptionT[IO, *], Int](OptionT.none[IO, Int]) + .map(f) + .parJoinUnbounded + .compile + .toList + .value + .flatMap { actual => + IO(assertEquals(actual, None)) + } + } + } } From 31ce7daf3450f9986fe4d7ee52466f5a441060d6 Mon Sep 17 00:00:00 2001 From: Vasil Vasilev Date: Fri, 10 Sep 2021 14:13:13 +0200 Subject: [PATCH 3/5] Implement `onOutcome` --- core/shared/src/main/scala/fs2/Stream.scala | 55 +++++++++------------ 1 file changed, 24 insertions(+), 31 deletions(-) diff --git a/core/shared/src/main/scala/fs2/Stream.scala b/core/shared/src/main/scala/fs2/Stream.scala index 60bc437b44..f2bac6232e 100644 --- a/core/shared/src/main/scala/fs2/Stream.scala +++ b/core/shared/src/main/scala/fs2/Stream.scala @@ -3890,6 +3890,23 @@ object Stream extends StreamLowPriority { val awaitWhileRunning: F[Unit] = running.discrete.forall(_ > 0).compile.drain + def onOutcome( + oc: Outcome[F, Throwable, Unit], + cancelResult: Either[Throwable, Unit] + ): F[Unit] = + oc match { + case Outcome.Succeeded(fu) => + cancelResult.fold(t => stop(Some(t)), _ => outcomes.send(fu).void) + + case Outcome.Errored(t) => + CompositeFailure + .fromResults(Left(t), cancelResult) + .fold(t => stop(Some(t)), _ => F.unit) + + case Outcome.Canceled() => + cancelResult.fold(t => stop(Some(t)), _ => F.unit) + } + def runInner(inner: Stream[F, O], outerScope: Scope[F]): F[Unit] = F.uncancelable { _ => outerScope.lease @@ -3901,33 +3918,18 @@ object Stream extends StreamLowPriority { .interruptWhen(done.map(_.nonEmpty)) .compile .drain - .guaranteeCase { - case Outcome.Succeeded(fu) => - lease.cancel.flatMap { - case Left(t) => stop(Some(t)) - case Right(()) => outcomes.send(fu).void - } >> available.release >> decrementRunning - - case Outcome.Errored(t) => - lease.cancel.flatMap { cancelResult => - (CompositeFailure.fromResults(Left(t), cancelResult) match { - case Left(t) => stop(Some(t)) - case Right(()) => F.unit - }) - } >> available.release >> decrementRunning - - case Outcome.Canceled() => - lease.cancel.flatMap { - case Left(t) => stop(Some(t)) - case Right(()) => F.unit - } >> available.release >> decrementRunning - } + .guaranteeCase(oc => + lease.cancel + .flatMap(onOutcome(oc, _)) >> available.release >> decrementRunning + ) .attempt .void }.void } } + val RightUnit = Right(()) + def runOuter: F[Unit] = F.uncancelable { _ => outer @@ -3939,16 +3941,7 @@ object Stream extends StreamLowPriority { .interruptWhen(done.map(_.nonEmpty)) .compile .drain - .guaranteeCase { - case Outcome.Succeeded(fu) => - outcomes.send(fu) >> decrementRunning - - case Outcome.Errored(t) => - stop(Some(t)) >> decrementRunning - - case Outcome.Canceled() => - decrementRunning - } + .guaranteeCase(onOutcome(_, RightUnit) >> decrementRunning) .attempt .void } From 17c4416db396a10d81f603a0062840396f167a14 Mon Sep 17 00:00:00 2001 From: Vasil Vasilev Date: Fri, 10 Sep 2021 14:25:06 +0200 Subject: [PATCH 4/5] Drain before interruption --- core/shared/src/main/scala/fs2/Stream.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/shared/src/main/scala/fs2/Stream.scala b/core/shared/src/main/scala/fs2/Stream.scala index f2bac6232e..30cfabe1bd 100644 --- a/core/shared/src/main/scala/fs2/Stream.scala +++ b/core/shared/src/main/scala/fs2/Stream.scala @@ -3914,7 +3914,7 @@ object Stream extends StreamLowPriority { .flatMap { lease => F.start { inner.chunks - .evalMap(s => output.send(s).void) + .foreach(s => output.send(s).void) .interruptWhen(done.map(_.nonEmpty)) .compile .drain @@ -3938,6 +3938,7 @@ object Stream extends StreamLowPriority { Pull.getScope[F].flatMap(outerScope => Pull.eval(runInner(inner, outerScope))) ) ) + .drain .interruptWhen(done.map(_.nonEmpty)) .compile .drain From d3705c5db8b21b61083c6e5c5a413329e621a526 Mon Sep 17 00:00:00 2001 From: Vasil Vasilev Date: Fri, 10 Sep 2021 14:26:50 +0200 Subject: [PATCH 5/5] Handle error instead of attempt and void --- core/shared/src/main/scala/fs2/Stream.scala | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/core/shared/src/main/scala/fs2/Stream.scala b/core/shared/src/main/scala/fs2/Stream.scala index 30cfabe1bd..4a4f2e896e 100644 --- a/core/shared/src/main/scala/fs2/Stream.scala +++ b/core/shared/src/main/scala/fs2/Stream.scala @@ -3922,8 +3922,7 @@ object Stream extends StreamLowPriority { lease.cancel .flatMap(onOutcome(oc, _)) >> available.release >> decrementRunning ) - .attempt - .void + .handleError(_ => ()) }.void } } @@ -3943,8 +3942,7 @@ object Stream extends StreamLowPriority { .compile .drain .guaranteeCase(onOutcome(_, RightUnit) >> decrementRunning) - .attempt - .void + .handleError(_ => ()) } def outcomeJoiner: F[Unit] = @@ -3962,8 +3960,7 @@ object Stream extends StreamLowPriority { case Outcome.Canceled() => stop(None) >> output.close.void } - .attempt - .void + .handleError(_ => ()) def signalResult(fiber: Fiber[F, Throwable, Unit]): F[Unit] = done.get.flatMap { blah =>