Skip to content

Commit 299e03f

Browse files
authored
Merge pull request #52 from softwaremill/mapPar
mapPar for collections
2 parents f28aaba + 4aa8a6a commit 299e03f

File tree

4 files changed

+145
-0
lines changed

4 files changed

+145
-0
lines changed

README.md

+15
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,21 @@ val result: (Int, String) = par(computation1)(computation2)
4545

4646
If one of the computations fails, the other is interrupted, and `par` waits until both branches complete.
4747

48+
## Parallelize collection transformation
49+
50+
```scala
51+
import ox.mapPar
52+
53+
val input: List[Int] = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
54+
55+
val result: List[Int] = mapPar(4)(input)(_ + 1)
56+
// List(2, 3, 4, 5, 6, 7, 8, 9, 10, 11)
57+
```
58+
59+
If any transformation fails, the others are interrupted and `mapPar` rethrows the exception that was
60+
thrown by the transformation. The `parallelism` argument
61+
limits how many forks may process the collection concurrently.
62+
4863
## Race two computations
4964

5065
```scala

core/src/main/scala/ox/mapPar.scala

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package ox
2+
3+
import java.util.concurrent.Semaphore
4+
import scala.collection.IterableFactory
5+
6+
/** Applies `transform` to every element of `iterable` in forks started inside a `supervised` scope, allowing at most
  * `parallelism` transformations to run at the same time. Results are returned in the iteration order of the input.
  *
  * @param parallelism maximum number of concurrent forks
  * @param iterable collection to transform; this by-name argument is evaluated exactly once
  * @param transform transformation to apply to each element of `iterable`
  * @return the transformed elements, built with the `iterableFactory` of the input collection
  */
def mapPar[I, O, C[E] <: Iterable[E]](parallelism: Int)(iterable: => C[I])(transform: I => O): C[O] =
  // `iterable` is by-name: evaluate it once, so that the elements and the factory used for the result come from the
  // same instance and a costly or side-effecting argument is not run twice
  val source = iterable
  val permits = Semaphore(parallelism)

  supervised {
    // collect the forks into a strict Vector, so that all of them are started before the first join, also when the
    // input collection is lazy
    val forks = source.iterator.map { elem =>
      // blocks the calling thread until one of the `parallelism` permits is free
      permits.acquire()
      fork {
        val result = transform(elem)
        // NOTE(review): the permit is not released when `transform` throws; per the README a failure interrupts the
        // remaining computations, so no further fork should be started in that case - confirm
        permits.release()
        result
      }
    }.toVector
    // `iterableFactory` is statically typed for `Iterable`, hence the cast back to the concrete collection type
    forks.map(_.join()).to(source.iterableFactory.asInstanceOf[IterableFactory[C]])
  }

core/src/main/scala/ox/syntax.scala

+5
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ object syntax:
1212
def forkDaemon: Fork[T] = ox.forkDaemon(f)
1313
def forkUnsupervised: Fork[T] = ox.forkUnsupervised(f)
1414
def forkCancellable: CancellableFork[T] = ox.forkCancellable(f)
15+
16+
extension [T](f: => T)
1517
def timeout(duration: FiniteDuration): T = ox.timeout(duration)(f)
1618
def timeoutOption(duration: FiniteDuration): Option[T] = ox.timeoutOption(duration)(f)
1719
def scopedWhere[U](fl: ForkLocal[U], u: U): T = fl.scopedWhere(u)(f)
@@ -24,3 +26,6 @@ object syntax:
2426
def useInScope: T = ox.useCloseableInScope(f)
2527
def useScoped[U](p: T => U): U = ox.useScoped(f)(p)
2628
def useSupervised[U](p: T => U): U = ox.useSupervised(f)(p)
29+
30+
extension [I, C[E] <: Iterable[E]](f: => C[I])
  /** Extension-method form of `ox.mapPar`: forwards the receiver collection, `parallelism` and `transform` to it.
    *
    * @param parallelism passed to `ox.mapPar` as its `parallelism` argument
    * @param transform passed to `ox.mapPar` as its `transform` argument
    * @return the result of `ox.mapPar`, a collection of the same kind as the receiver
    */
  def mapPar[O](parallelism: Int)(transform: I => O): C[O] = ox.mapPar(parallelism)(f)(transform)
+100
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
package ox
2+
3+
import org.scalatest.flatspec.AnyFlatSpec
4+
import org.scalatest.matchers.should.Matchers
5+
import ox.syntax.mapPar
6+
import ox.util.Trail
7+
8+
import java.util.concurrent.atomic.AtomicInteger
9+
import scala.collection.IterableFactory
10+
import scala.collection.immutable.Iterable
11+
import scala.List
12+
13+
/** Tests of `mapPar`, invoked through the `ox.syntax` extension method. */
class MapParTest extends AnyFlatSpec with Matchers {
  "mapPar" should "output the same type as input" in {
    val input = List(1, 2, 3)
    val result = input.mapPar(1)(identity)
    result shouldBe a[List[_]]
  }

  it should "run computations in parallel" in {
    val InputElements = 17
    val TransformationMillis: Long = 100

    val input = (0 to InputElements)
    def transformation(i: Int) = {
      Thread.sleep(TransformationMillis)
      i + 1
    }

    val start = System.currentTimeMillis()
    val result = input.to(Iterable).mapPar(5)(transformation)
    val end = System.currentTimeMillis()

    result.toList should contain theSameElementsInOrderAs (input.map(_ + 1))
    // running the InputElements + 1 transformations one after another would take longer than this bound
    (end - start) should be < (InputElements * TransformationMillis)
  }

  it should "run not more computations than limit" in {
    val Parallelism = 5

    val input = (1 to 158)

    // Counts the transformations currently in progress and remembers the highest count observed.
    class MaxCounter {
      private val counter = new AtomicInteger(0)
      // kept in its own atomic: updating a plain `var` from concurrent forks could lose a higher value
      private val highest = new AtomicInteger(0)

      def max: Int = highest.get()

      def increment(): Int = {
        val current = counter.incrementAndGet()
        // the update function is side-effect free, so it is safe if `updateAndGet` retries it
        highest.updateAndGet(seen => math.max(seen, current))
        current
      }

      def decrement(): Int = counter.decrementAndGet()
    }

    val maxCounter = new MaxCounter

    def transformation(i: Int) = {
      maxCounter.increment()
      Thread.sleep(10)
      maxCounter.decrement()
    }

    input.to(Iterable).mapPar(Parallelism)(transformation)

    maxCounter.max should be <= Parallelism
  }

  it should "interrupt other computations if one fails" in {
    val InputElements = 18
    val TransformationMillis: Long = 100
    val trail = Trail()

    val input = (0 to InputElements)

    def transformation(i: Int) = {
      if (i == 4) {
        trail.add("exception")
        throw new Exception("boom")
      } else {
        Thread.sleep(TransformationMillis)
        trail.add("transformation")
        i + 1
      }
    }

    try {
      input.to(Iterable).mapPar(5)(transformation)
    } catch {
      case e: Exception if e.getMessage == "boom" => trail.add("catch")
    }

    // wait longer than a single transformation takes: any computation that was not interrupted would have added
    // "transformation" to the trail by now
    Thread.sleep(300)
    trail.add("all done")

    trail.get shouldBe Vector("exception", "catch", "all done")
  }
}

0 commit comments

Comments
 (0)