Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix soundness issue with stream interruption + translation #2145

Merged
merged 10 commits into from
Nov 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions core/shared/src/main/scala/fs2/Compiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ private[fs2] trait CompilerLowPriority2 {
init: B
)(foldChunk: (B, Chunk[O]) => B): Resource[F, B] =
Resource
.makeCase(CompileScope.newRoot[F])((scope, ec) => scope.close(ec).rethrow)
.makeCase(Scope.newRoot[F])((scope, ec) => scope.close(ec).rethrow)
.evalMap(scope => Pull.compile(stream, scope, true, init)(foldChunk))
}
}
Expand Down Expand Up @@ -152,7 +152,7 @@ object Compiler extends CompilerLowPriority {
init: Out,
foldChunk: (Out, Chunk[O]) => Out
): F[Out] =
CompileScope
Scope
.newRoot[F](this)
.flatMap(scope =>
Pull
Expand Down Expand Up @@ -193,7 +193,7 @@ object Compiler extends CompilerLowPriority {
foldChunk: (Out, Chunk[O]) => Out
): F[Out] =
Resource
.makeCase(CompileScope.newRoot[F](this))((scope, ec) => scope.close(ec).rethrow)
.makeCase(Scope.newRoot[F](this))((scope, ec) => scope.close(ec).rethrow)
.use(scope => Pull.compile[F, O, Out](p, scope, false, init)(foldChunk))
}

Expand Down
73 changes: 51 additions & 22 deletions core/shared/src/main/scala/fs2/Pull.scala
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ object Pull extends PullLowPriority {
)(implicit F: MonadError[F, Throwable]): Pull[F, INothing, Stream[F, O]] =
for {
scope <- Pull.getScope[F]
lease <- Pull.eval(scope.leaseOrError)
lease <- Pull.eval(scope.lease)
} yield s.onFinalize(lease.cancel.redeemWith(F.raiseError(_), _ => F.unit))

/** Repeatedly uses the output of the pull as input for the next step of the pull.
Expand Down Expand Up @@ -269,7 +269,7 @@ object Pull extends PullLowPriority {
/** Gets the current scope, allowing manual leasing or interruption.
* This is a low-level method and generally should not be used by user code.
*/
def getScope[F[_]]: Pull[F, INothing, Scope[F]] = GetScope[F]()
private[fs2] def getScope[F[_]]: Pull[F, INothing, Scope[F]] = GetScope[F]()

/** Returns a pull that evaluates the supplied by-name each time the pull is used,
* allowing use of a mutable value in pull computations.
Expand Down Expand Up @@ -528,6 +528,9 @@ object Pull extends PullLowPriority {
useInterruption: Boolean
) extends Action[F, O, Unit]

private final case class InterruptWhen[+F[_]](haltOnSignal: F[Either[Throwable, Unit]])
extends AlgEffect[F, Unit]

// `InterruptedScope` contains id of the scope currently being interrupted
// together with any errors accumulated during interruption process
private final case class CloseScope(
Expand All @@ -536,8 +539,7 @@ object Pull extends PullLowPriority {
exitCase: ExitCase
) extends AlgEffect[Pure, Unit]

private final case class GetScope[F[_]]() extends AlgEffect[Pure, CompileScope[F]]
private[fs2] def getScopeInternal[F[_]]: Pull[Pure, INothing, CompileScope[F]] = GetScope[F]()
private final case class GetScope[F[_]]() extends AlgEffect[Pure, Scope[F]]

private[fs2] def stepLeg[F[_], O](
leg: Stream.StepLeg[F, O]
Expand All @@ -558,19 +560,23 @@ object Pull extends PullLowPriority {
*/
private[fs2] def interruptScope[F[_], O](s: Pull[F, O, Unit]): Pull[F, O, Unit] = InScope(s, true)

private[fs2] def interruptWhen[F[_], O](
haltOnSignal: F[Either[Throwable, Unit]]
): Pull[F, O, Unit] = InterruptWhen(haltOnSignal)

private[fs2] def uncons[F[_], X, O](
s: Pull[F, O, Unit]
): Pull[F, X, Option[(Chunk[O], Pull[F, O, Unit])]] =
Step(s, None).map(_.map { case (h, _, t) => (h, t.asInstanceOf[Pull[F, O, Unit]]) })

/* Left-folds the output of a stream.
*
* Interruption of the stream is tightly coupled between Pull and CompileScope.
* Interruption of the stream is tightly coupled between Pull and Scope.
* Reason for this is unlike interruption of `F` type (e.g. IO) we need to find
* recovery point where stream evaluation has to continue in Stream algebra.
*
* As such the `Token` is passed to Result.Interrupted as glue between Pull that allows pass-along
* the information to correctly compute recovery point after interruption was signalled via `CompileScope`.
* the information to correctly compute recovery point after interruption was signalled via `Scope`.
*
* This token indicates scope of the computation where interruption actually happened.
* This is used to precisely find most relevant interruption scope where interruption shall be resumed
Expand All @@ -581,16 +587,16 @@ object Pull extends PullLowPriority {
*/
private[fs2] def compile[F[_], O, B](
stream: Pull[F, O, Unit],
initScope: CompileScope[F],
initScope: Scope[F],
extendLastTopLevelScope: Boolean,
init: B
)(foldChunk: (B, Chunk[O]) => B)(implicit
F: MonadError[F, Throwable]
): F[B] = {

sealed trait R[+G[_], +X]
case class Done(scope: CompileScope[F]) extends R[Pure, INothing]
case class Out[+G[_], +X](head: Chunk[X], scope: CompileScope[F], tail: Pull[G, X, Unit])
case class Done(scope: Scope[F]) extends R[Pure, INothing]
case class Out[+G[_], +X](head: Chunk[X], scope: Scope[F], tail: Pull[G, X, Unit])
extends R[G, X]
case class Interrupted(scopeId: Token, err: Option[Throwable]) extends R[Pure, INothing]

Expand All @@ -603,19 +609,19 @@ object Pull extends PullLowPriority {
case Interrupted(scopeId, err) => interrupted(scopeId, err)
}

def done(scope: CompileScope[F]): End
def out(head: Chunk[X], scope: CompileScope[F], tail: Pull[G, X, Unit]): End
def done(scope: Scope[F]): End
def out(head: Chunk[X], scope: Scope[F], tail: Pull[G, X, Unit]): End
def interrupted(scopeId: Token, err: Option[Throwable]): End
}

def go[G[_], X](
scope: CompileScope[F],
extendedTopLevelScope: Option[CompileScope[F]],
scope: Scope[F],
extendedTopLevelScope: Option[Scope[F]],
translation: G ~> F,
stream: Pull[G, X, Unit]
): F[R[G, X]] = {

def interruptGuard(scope: CompileScope[F], view: Cont[Nothing, G, X])(
def interruptGuard(scope: Scope[F], view: Cont[Nothing, G, X])(
next: => F[R[G, X]]
): F[R[G, X]] =
scope.isInterrupted.flatMap {
Expand Down Expand Up @@ -655,10 +661,10 @@ object Pull extends PullLowPriority {

def viewRunner(view: Cont[Unit, G, X]): RunR[G, X, F[R[G, X]]] =
new RunR[G, X, F[R[G, X]]] {
def done(doneScope: CompileScope[F]): F[R[G, X]] =
def done(doneScope: Scope[F]): F[R[G, X]] =
go(doneScope, extendedTopLevelScope, translation, view(Result.unit))

def out(head: Chunk[X], scope: CompileScope[F], tail: Pull[G, X, Unit]): F[R[G, X]] = {
def out(head: Chunk[X], scope: Scope[F], tail: Pull[G, X, Unit]): F[R[G, X]] = {
val contTail = new Bind[G, X, Unit, Unit](tail) {
def cont(r: Result[Unit]) = view(r)
}
Expand Down Expand Up @@ -695,13 +701,13 @@ object Pull extends PullLowPriority {

class StepRunR() extends RunR[G, Y, F[R[G, X]]] {

def done(scope: CompileScope[F]): F[R[G, X]] =
def done(scope: Scope[F]): F[R[G, X]] =
interruptGuard(scope, view) {
val result = Result.Succeeded(None)
go(scope, extendedTopLevelScope, translation, view(result))
}

def out(head: Chunk[Y], outScope: CompileScope[F], tail: Pull[G, Y, Unit]): F[R[G, X]] = {
def out(head: Chunk[Y], outScope: Scope[F], tail: Pull[G, Y, Unit]): F[R[G, X]] = {
// if we originally swapped scopes we want to return the original
// scope back to the go as that is the scope that is expected to be here.
val nextScope = if (u.scope.isEmpty) outScope else scope
Expand All @@ -717,7 +723,7 @@ object Pull extends PullLowPriority {
}

// if scope was specified in step, try to find it, otherwise use the current scope.
val stepScopeF: F[CompileScope[F]] = u.scope match {
val stepScopeF: F[Scope[F]] = u.scope match {
case None => F.pure(scope)
case Some(scopeId) => scope.shiftScope(scopeId, u.toString)
}
Expand Down Expand Up @@ -760,6 +766,26 @@ object Pull extends PullLowPriority {
interruptGuard(scope, view)(cont)
}

def goInterruptWhen(
haltOnSignal: F[Either[Throwable, Unit]],
view: Cont[Unit, G, X]
): F[R[G, X]] = {
val onScope = scope.acquireResource(
_ => scope.interruptWhen(haltOnSignal),
(f: Fiber[F, Throwable, Unit], _: ExitCase) => f.cancel
)
val cont = onScope.flatMap { outcome =>
val result = outcome match {
case Outcome.Succeeded(Right(_)) => Result.Succeeded(())
case Outcome.Succeeded(Left(scopeId)) => Result.Interrupted(scopeId, None)
case Outcome.Canceled() => Result.Interrupted(scope.id, None)
case Outcome.Errored(err) => Result.Fail(err)
}
go(scope, extendedTopLevelScope, translation, view(result))
}
interruptGuard(scope, view)(cont)
}

def goInScope(
stream: Pull[G, X, Unit],
useInterruption: Boolean,
Expand Down Expand Up @@ -803,7 +829,7 @@ object Pull extends PullLowPriority {
}

def goCloseScope(close: CloseScope, view: Cont[Unit, G, X]): F[R[G, X]] = {
def closeAndGo(toClose: CompileScope[F]) =
def closeAndGo(toClose: Scope[F]) =
toClose.close(close.exitCase).flatMap { r =>
toClose.openAncestor.flatMap { ancestor =>
val res = close.interruption match {
Expand All @@ -824,7 +850,7 @@ object Pull extends PullLowPriority {
}
}

val scopeToClose: F[Option[CompileScope[F]]] = scope
val scopeToClose: F[Option[Scope[F]]] = scope
.findSelfOrAncestor(close.scopeId)
.pure[F]
.orElse(scope.findSelfOrChild(close.scopeId))
Expand Down Expand Up @@ -896,6 +922,9 @@ object Pull extends PullLowPriority {
val uu = inScope.stream.asInstanceOf[Pull[g, X, Unit]]
goInScope(uu, inScope.useInterruption, view.asInstanceOf[View[g, X, Unit]])

case int: InterruptWhen[g] =>
goInterruptWhen(translation(int.haltOnSignal), view)

case close: CloseScope =>
goCloseScope(close, view.asInstanceOf[View[G, X, Unit]])
}
Expand All @@ -904,7 +933,7 @@ object Pull extends PullLowPriority {

val initFk: F ~> F = cats.arrow.FunctionK.id[F]

def outerLoop(scope: CompileScope[F], accB: B, stream: Pull[F, O, Unit]): F[B] =
def outerLoop(scope: Scope[F], accB: B, stream: Pull[F, O, Unit]): F[B] =
go[F, O](scope, None, initFk, stream).flatMap {
case Done(_) => F.pure(accB)
case out: Out[f, o] =>
Expand Down
87 changes: 0 additions & 87 deletions core/shared/src/main/scala/fs2/Scope.scala

This file was deleted.

11 changes: 4 additions & 7 deletions core/shared/src/main/scala/fs2/Stream.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1603,10 +1603,7 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F,
def interruptWhen[F2[x] >: F[x]: Concurrent](
haltOnSignal: F2[Either[Throwable, Unit]]
): Stream[F2, O] =
Stream
.getScope[F2]
.flatMap(scope => Stream.supervise(haltOnSignal.flatMap(scope.interrupt)) >> this)
.interruptScope
(Pull.interruptWhen(haltOnSignal) >> this.pull.echo).stream.interruptScope

/** Creates a scope that may be interrupted by calling scope#interrupt.
*/
Expand Down Expand Up @@ -3202,7 +3199,7 @@ object Stream extends StreamLowPriority {
/** Gets the current scope, allowing manual leasing or interruption.
* This is a low-level method and generally should not be used by user code.
*/
def getScope[F[x] >: Pure[x]]: Stream[F, Scope[F]] =
private def getScope[F[x] >: Pure[x]]: Stream[F, Scope[F]] =
new Stream(Pull.getScope[F].flatMap(Pull.output1(_)))

/** A stream that never emits and never terminates.
Expand Down Expand Up @@ -3644,7 +3641,7 @@ object Stream extends StreamLowPriority {
// and that it must be released once the inner stream terminates or fails.
def runInner(inner: Stream[F, O], outerScope: Scope[F]): F[Unit] =
F.uncancelable { _ =>
outerScope.leaseOrError.flatMap { lease =>
outerScope.lease.flatMap { lease =>
available.acquire >>
incrementRunning >>
F.start {
Expand Down Expand Up @@ -4006,7 +4003,7 @@ object Stream extends StreamLowPriority {
* If you are not pulling from multiple streams, consider using `uncons`.
*/
def stepLeg: Pull[F, INothing, Option[StepLeg[F, O]]] =
Pull.getScopeInternal[F].flatMap { scope =>
Pull.getScope[F].flatMap { scope =>
new StepLeg[F, O](Chunk.empty, scope.id, self.underlying).stepLeg
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
package fs2.internal

import cats.{Applicative, Id}
import cats.effect.kernel.{Concurrent, Deferred, Outcome, Ref}
import cats.effect.kernel.{Concurrent, Deferred, Fiber, Outcome, Ref}
import cats.effect.kernel.implicits._
import cats.syntax.all._
import InterruptContext.InterruptionOutcome
Expand All @@ -45,9 +45,12 @@ final private[fs2] case class InterruptContext[F[_]](
cancelParent: F[Unit]
)(implicit F: Concurrent[F]) { self =>

def complete(outcome: InterruptionOutcome): F[Unit] =
private def complete(outcome: InterruptionOutcome): F[Unit] =
ref.update(_.orElse(Some(outcome))).guarantee(deferred.complete(outcome).void)

def completeWhen(outcome: F[InterruptionOutcome]): F[Fiber[F, Throwable, Unit]] =
F.start(outcome.flatMap(complete))

/** Creates a [[InterruptContext]] for a child scope which can be interruptible as well.
*
* In case the child scope is interruptible, this will ensure that this scope interrupt will
Expand Down
Loading