Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mapPar for collections #52

Merged
merged 8 commits into from
Nov 23, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,20 @@ val result: (Int, String) = par(computation1)(computation2)

If one of the computations fails, the other is interrupted, and `par` waits until both branches complete.

## Parallelize collection transformation

```scala
import ox.mapPar

val input: List[Int] = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)

val result: List[Int] = mapPar(input)(4)(_ + 1)
// (2, 3, 4, 5, 6, 7, 8, 9, 10)
```

If any transformation fails, others are interrupted and `mapPar` throws exception. Parallelism
limits how many concurrent forks are going to process the collection.

## Race two computations

```scala
Expand Down
13 changes: 13 additions & 0 deletions core/src/main/scala/ox/mapPar.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package ox

import scala.collection.IterableFactory

def mapPar[I, O, C[E] <: Iterable[E]](parallelism: Int)(iterable: => C[I])(transform: I => O): C[O] =
val workers = Math.min(parallelism, iterable.size)
val elementsInSlide = Math.ceil(iterable.size.toDouble / workers).toInt
val subCollections = iterable.sliding(elementsInSlide, elementsInSlide)

supervised {
val forks = subCollections.toList.map(s => fork(s.map(transform)))
forks.flatMap(_.join()).to(iterable.iterableFactory.asInstanceOf[IterableFactory[C]])
}
5 changes: 5 additions & 0 deletions core/src/main/scala/ox/syntax.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ object syntax:
def forkDaemon: Fork[T] = ox.forkDaemon(f)
def forkUnsupervised: Fork[T] = ox.forkUnsupervised(f)
def forkCancellable: CancellableFork[T] = ox.forkCancellable(f)

extension [T](f: => T)
def timeout(duration: FiniteDuration): T = ox.timeout(duration)(f)
def timeoutOption(duration: FiniteDuration): Option[T] = ox.timeoutOption(duration)(f)
def scopedWhere[U](fl: ForkLocal[U], u: U): T = fl.scopedWhere(u)(f)
Expand All @@ -24,3 +26,6 @@ object syntax:
def useInScope: T = ox.useCloseableInScope(f)
def useScoped[U](p: T => U): U = ox.useScoped(f)(p)
def useSupervised[U](p: T => U): U = ox.useSupervised(f)(p)

extension [I, C[E] <: Iterable[E]](f: => C[I])
def mapParWith[O](parallelism: Int)(transform: I => O) = ox.mapPar(parallelism)(f)(transform)
79 changes: 79 additions & 0 deletions core/src/test/scala/ox/MapParTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package ox

import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import ox.syntax.mapParWith
import ox.util.Trail

import scala.collection.IterableFactory
import scala.collection.immutable.Iterable
import scala.List

class MapParTest extends AnyFlatSpec with Matchers {
"mapPar" should "output the same type as input" in {
val input = List(1, 2, 3)
val result = input.mapParWith(1)(identity)
result shouldBe a[List[_]]
}

it should "run computations in parallel" in {
val InputElements = 17
val TransformationMillis: Long = 100

val input = (0 to InputElements)
def transformation(i: Int) = {
Thread.sleep(TransformationMillis)
i + 1
}

val start = System.currentTimeMillis()
val result = input.to(Iterable).mapParWith(5)(transformation)
val end = System.currentTimeMillis()

result.toList should contain theSameElementsInOrderAs (input.map(_ + 1))
(end - start) should be < (InputElements * TransformationMillis)
}

it should "run not more computations than limit" in {
val Parallelism = 5

val input = (1 to 17)

def transformation(i: Int) = {
Thread.currentThread().threadId()
}

val result = input.to(Iterable).mapParWith(Parallelism)(transformation)
result.toSet.size shouldBe Parallelism
}

it should "interrupt other computations in one fails" in {
val InputElements = 18
val TransformationMillis: Long = 100
val trail = Trail()

val input = (0 to InputElements)

def transformation(i: Int) = {
if (i == 4) {
trail.add("exception")
throw new Exception("boom")
} else {
Thread.sleep(TransformationMillis)
trail.add("transformation")
i + 1
}
}

try {
input.to(Iterable).mapParWith(5)(transformation)
} catch {
case e: Exception if e.getMessage == "boom" => trail.add("catch")
}

Thread.sleep(300)
trail.add("all done")

trail.get shouldBe Vector("exception", "catch", "all done")
}
}