collectPar, filterPar, foreachPar for collections added #54

Merged
merged 6 commits into from
Nov 24, 2023
Changes from 1 commit
51 changes: 48 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,21 +45,66 @@ val result: (Int, String) = par(computation1)(computation2)

If one of the computations fails, the other is interrupted, and `par` waits until both branches complete.

-## Parallelize collection transformation
+## Parallelize collection operations

### mapPar

```scala
-import ox.mapPar
+import ox.syntax.mapPar

val input: List[Int] = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)

-val result: List[Int] = mapPar(input)(4)(_ + 1)
+val result: List[Int] = input.mapPar(4)(_ + 1)
// (2, 3, 4, 5, 6, 7, 8, 9, 10, 11)
```

If any transformation fails, the others are interrupted and `mapPar` rethrows the exception thrown by
the failed transformation. The `parallelism` argument limits how many forks process the collection
concurrently.

### foreachPar

```scala
import ox.syntax.foreachPar

val input: List[Int] = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)

input.foreachPar(4)(i => println(i))
// Prints each element of the list, might be in any order
```

Similar to `mapPar`, but returns nothing (`Unit`).

### filterPar

```scala
import ox.syntax.filterPar

val input: List[Int] = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)

val result: List[Int] = input.filterPar(4)(_ % 2 == 0)
// (2, 4, 6, 8, 10)
```

Filters the collection in parallel using the provided predicate. If any predicate call fails, the exception
is rethrown and the other forks evaluating the predicate are interrupted.

### collectPar

```scala
import ox.syntax.collectPar

val input: List[Int] = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)

val result: List[Int] = input.collectPar(4) {
  case i if i % 2 == 0 => i + 1
}
// (3, 5, 7, 9, 11)
```

Similar to `mapPar`, but applies the transformation only to elements for which
the partial function is defined. Other elements are skipped.

## Race two computations

```scala
// ... (rest of this hunk is collapsed in the diff view)
```
32 changes: 32 additions & 0 deletions core/src/main/scala/ox/collectPar.scala
@@ -0,0 +1,32 @@
package ox

import scala.collection.IterableFactory

/** Runs a partial function in parallel on each element of `iterable` for which the partial function is defined.
  * Elements for which the function is not defined are skipped.
  * Uses at most `parallelism` forks concurrently.
  *
  * @tparam I type of elements of `iterable`
  * @tparam O type of elements of result
  * @tparam C type of `iterable`, must be a subtype of `Iterable`
  *
  * @param parallelism maximum number of concurrent forks
  * @param iterable collection to transform
  * @param pf partial function to apply to those elements of `iterable` for which it is defined
  *
  * @return collection of results of applying `pf` to elements of `iterable` for which it is defined. The returned
  *         collection is of the same type as `iterable`
  */
def collectPar[I, O, C[E] <: Iterable[E]](parallelism: Int)(iterable: => C[I])(pf: PartialFunction[I, O]): C[O] =

  def nonPartialOperation(elem: I): Option[O] =
    if pf.isDefinedAt(elem) then
      Some(pf(elem))
    else
      None

  def handleOutputs(outputs: Seq[Option[O]]): C[O] =
    outputs.collect { case Some(output) => output }.to(iterable.iterableFactory.asInstanceOf[IterableFactory[C]])

  commonPar(parallelism, iterable, nonPartialOperation, handleOutputs)
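
The `nonPartialOperation` helper turns the partial function into a total one that returns an `Option`, which is also what the standard library's `PartialFunction.lift` does. A minimal sequential sketch of the same lift-then-keep-defined idea, using only the standard library (the names `CollectLiftSketch` and `collectViaLift` are illustrative assumptions, not part of ox):

```scala
object CollectLiftSketch {
  // Hypothetical helper, not part of ox: a sequential analogue of collectPar.
  def collectViaLift[I, O](input: List[I])(pf: PartialFunction[I, O]): List[O] = {
    // pf.lift(elem) is Some(pf(elem)) where pf is defined and None elsewhere
    val outputs: List[Option[O]] = input.map(pf.lift)
    // keep only the defined results, preserving the input order
    outputs.collect { case Some(output) => output }
  }

  // Same input and partial function as the README example
  val evensPlusOne: List[Int] =
    collectViaLift((1 to 10).toList) { case i if i % 2 == 0 => i + 1 }
}
```

Here `CollectLiftSketch.evensPlusOne` is `List(3, 5, 7, 9, 11)`, matching the README example for `collectPar`.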

21 changes: 21 additions & 0 deletions core/src/main/scala/ox/common/commonPar.scala
@@ -0,0 +1,21 @@
package ox

import java.util.concurrent.Semaphore

private[ox] def commonPar[I, O, C[E] <: Iterable[E], FO](parallelism: Int, iterable: => C[I], transform: I => O, handleOutputs: Seq[O] => FO): FO =
  val s = Semaphore(parallelism)

  supervised {
    val forks = iterable.map { elem =>
      s.acquire()
      fork {
        val o = transform(elem)
        s.release()
        o
      }
    }
    val outputs = forks.toSeq.map(f => f.join())
    handleOutputs(outputs)
  }
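
The permit pattern in `commonPar` can be tried outside ox with standard-library pieces only. The sketch below is a stand-in under stated assumptions, not the library's implementation: `BoundedParSketch` and `mapBounded` are hypothetical names, `Future`s on the global execution context replace ox forks, and there is no supervision scope, so a failing task is rethrown to the caller but does not interrupt tasks that are already running. Releasing the permit in `finally` is a choice made for this sketch.

```scala
import java.util.concurrent.Semaphore
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.Duration

object BoundedParSketch {
  // Assumption: the global execution context stands in for ox forks.
  private implicit val ec: ExecutionContext = ExecutionContext.global

  // Hypothetical analogue of commonPar + mapPar for plain Seqs.
  def mapBounded[I, O](parallelism: Int)(input: Seq[I])(transform: I => O): Seq[O] = {
    val permits = new Semaphore(parallelism)
    val futures = input.map { elem =>
      permits.acquire() // blocks the submitting thread until a slot is free
      Future {
        try transform(elem)
        finally permits.release() // return the permit even if transform throws
      }
    }
    // Awaiting in input order keeps the results in input order;
    // the exception of a failed task is rethrown here.
    futures.map(f => Await.result(f, Duration.Inf))
  }
}
```

For example, `BoundedParSketch.mapBounded(4)(1 to 10)(_ + 1)` yields the elements 2 through 11 in input order, with at most four transformations running at any moment.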


27 changes: 27 additions & 0 deletions core/src/main/scala/ox/filterPar.scala
@@ -0,0 +1,27 @@
package ox

import scala.collection.IterableFactory

/** Runs the predicate in parallel on each element of `iterable`. Elements for which the predicate returns `true`
  * are returned in the same order as in `iterable`. Elements for which the predicate returns `false` are skipped.
  * Uses at most `parallelism` forks concurrently.
  *
  * @tparam I type of elements in `iterable`
  * @tparam C type of `iterable`, must be a subtype of `Iterable`
  *
  * @param parallelism maximum number of concurrent forks
  * @param iterable collection to filter
  * @param predicate predicate to run on each element of `iterable`
  *
  * @return filtered collection
  */
def filterPar[I, C[E] <: Iterable[E]](parallelism: Int)(iterable: => C[I])(predicate: I => Boolean): C[I] =

  def addCalculatedFilter(elem: I): (Boolean, I) =
    (predicate(elem), elem)

  def handleOutputs(outputs: Seq[(Boolean, I)]): C[I] =
    outputs.collect { case (true, elem) => elem }.to(iterable.iterableFactory.asInstanceOf[IterableFactory[C]])

  commonPar(parallelism, iterable, addCalculatedFilter, handleOutputs)
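
`filterPar` works in two steps: each element is paired with its predicate result, and afterwards the elements tagged `true` are collected and rebuilt into the input's collection type through `iterableFactory`. A sequential sketch of those two steps with the standard library only, fixed to `Vector[Int]` to keep the types simple (`FilterTagSketch` and `filterTagged` are hypothetical names, not part of ox):

```scala
import scala.collection.IterableFactory

object FilterTagSketch {
  // Hypothetical sequential analogue of filterPar, not part of ox.
  def filterTagged(input: Vector[Int])(predicate: Int => Boolean): Vector[Int] = {
    // Step 1: pair each element with its predicate result
    // (this is the part filterPar runs in forks).
    val tagged: Seq[(Boolean, Int)] = input.map(elem => (predicate(elem), elem))
    // Step 2: keep the elements tagged true and rebuild the input's collection type.
    val factory: IterableFactory[Vector] = input.iterableFactory
    tagged.collect { case (true, elem) => elem }.to(factory)
  }

  // Same input and predicate as the README example
  val evens: Vector[Int] = filterTagged((1 to 10).toVector)(_ % 2 == 0)
}
```

Because the tagged pairs stay in input order, the filtered result does too: `FilterTagSketch.evens` is `Vector(2, 4, 6, 8, 10)`.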

20 changes: 20 additions & 0 deletions core/src/main/scala/ox/foreachPar.scala
@@ -0,0 +1,20 @@
package ox

import scala.annotation.unused

/**
 * Parallelizes a foreach operation. Runs the operation on each element of the iterable in parallel,
 * using at most `parallelism` forks concurrently.
 *
 * @tparam I the type of the elements in the iterable
 * @tparam C the type of the iterable, must be a subtype of Iterable[I]
 *
 * @param parallelism the maximum number of concurrent forks
 * @param iterable the collection to iterate over
 * @param operation the operation to perform on each element
 */
def foreachPar[I, C <: Iterable[I]](parallelism: Int)(iterable: => C)(operation: I => Any): Unit =
  def handleOutputs(@unused outputs: Seq[_]): Unit = ()

  commonPar(parallelism, iterable, operation, handleOutputs)

24 changes: 11 additions & 13 deletions core/src/main/scala/ox/mapPar.scala
@@ -1,25 +1,23 @@
package ox

-import java.util.concurrent.Semaphore
import scala.collection.IterableFactory

/** Runs parallel transformations on `iterable`, using at most `parallelism` forks concurrently.
  *
  * @tparam I type of elements in `iterable`
  * @tparam O type of elements in result
  * @tparam C type of `iterable`, must be a subtype of `Iterable`
  *
  * @param parallelism maximum number of concurrent forks
  * @param iterable collection to transform
  * @param transform transformation to apply to each element of `iterable`
  *
  * @return transformed collection of the same type as the input one
  */
def mapPar[I, O, C[E] <: Iterable[E]](parallelism: Int)(iterable: => C[I])(transform: I => O): C[O] =
-  val s = Semaphore(parallelism)

-  supervised {
-    val forks = iterable.map { elem =>
-      s.acquire()
-      fork {
-        val o = transform(elem)
-        s.release()
-        o
-      }
-    }
-    forks.toSeq.map(f => f.join()).to(iterable.iterableFactory.asInstanceOf[IterableFactory[C]])
-  }
+  def handleOutputs(outputs: Seq[O]): C[O] =
+    outputs.to(iterable.iterableFactory.asInstanceOf[IterableFactory[C]])
+
+  commonPar(parallelism, iterable, transform, handleOutputs)

3 changes: 3 additions & 0 deletions core/src/main/scala/ox/syntax.scala
Expand Up @@ -29,3 +29,6 @@ object syntax:

  extension [I, C[E] <: Iterable[E]](f: => C[I])
    def mapPar[O](parallelism: Int)(transform: I => O) = ox.mapPar(parallelism)(f)(transform)
    def collectPar[O](parallelism: Int)(pf: PartialFunction[I, O]) = ox.collectPar(parallelism)(f)(pf)
    def foreachPar(parallelism: Int)(operation: I => Any) = ox.foreachPar(parallelism)(f)(operation)
    def filterPar(parallelism: Int)(predicate: I => Boolean) = ox.filterPar(parallelism)(f)(predicate)
83 changes: 83 additions & 0 deletions core/src/test/scala/ox/CollectParTest.scala
@@ -0,0 +1,83 @@
package ox

import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import ox.syntax.collectPar
import ox.util.{MaxCounter, Trail}

import scala.List
import scala.collection.IterableFactory
import scala.collection.immutable.Iterable

class CollectParTest extends AnyFlatSpec with Matchers {
  "collectPar" should "output the same type as input" in {
    val input = List(1, 2, 3)
    val result = input.collectPar(1)(identity)
    result shouldBe a[List[_]]
  }

  it should "run computations in parallel" in {
    val InputElements = 17
    val TransformationMillis: Long = 100

    val input = (0 to InputElements)
    val pf: PartialFunction[Int, Int] = {
      case i if i % 2 == 0 =>
        // sleep so that the timing assertion below actually exercises parallelism
        Thread.sleep(TransformationMillis)
        i
    }

    val start = System.currentTimeMillis()
    val result = input.to(Iterable).collectPar(5)(pf)
    val end = System.currentTimeMillis()

    result.toList should contain theSameElementsInOrderAs List(0, 2, 4, 6, 8, 10, 12, 14, 16)
    (end - start) should be < (InputElements * TransformationMillis)
  }

  it should "not run more computations than the limit" in {
    val Parallelism = 5

    val input = (1 to 158)

    val maxCounter = new MaxCounter()

    def transformation(i: Int) = {
      maxCounter.increment()
      Thread.sleep(10)
      maxCounter.decrement()
    }

    input.to(Iterable).collectPar(Parallelism)(transformation)

    maxCounter.max should be <= Parallelism
  }

  it should "interrupt other computations if one fails" in {
    val InputElements = 18
    val TransformationMillis: Long = 100
    val trail = Trail()

    val input = (0 to InputElements)

    def transformation(i: Int) = {
      if (i == 4) {
        trail.add("exception")
        throw new Exception("boom")
      } else {
        Thread.sleep(TransformationMillis)
        trail.add("transformation")
        i + 1
      }
    }

    try {
      input.to(Iterable).collectPar(5)(transformation)
    } catch {
      case e: Exception if e.getMessage == "boom" => trail.add("catch")
    }

    Thread.sleep(300)
    trail.add("all done")

    trail.get shouldBe Vector("exception", "catch", "all done")
  }
}
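
The tests above rely on `ox.util.MaxCounter`, which is not part of this diff. A minimal sketch of what such a helper could look like, assuming it only needs to track the highest number of computations running at the same time (the class name `MaxCounterSketch` and its internals are hypothetical):

```scala
import java.util.concurrent.atomic.AtomicInteger

// Hypothetical sketch; the real ox.util.MaxCounter is not shown in this diff.
class MaxCounterSketch {
  private val current = new AtomicInteger(0) // computations running right now
  private val highest = new AtomicInteger(0) // largest value `current` has reached

  def increment(): Int = {
    val now = current.incrementAndGet()
    // atomically raise the recorded maximum if this increment exceeded it
    highest.updateAndGet(prev => math.max(prev, now))
    now
  }

  def decrement(): Int = current.decrementAndGet()

  def max: Int = highest.get()
}
```

A test would call `increment()` when a computation starts and `decrement()` when it ends, then assert that `max` never exceeded the configured parallelism.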
85 changes: 85 additions & 0 deletions core/src/test/scala/ox/FilterParTest.scala
@@ -0,0 +1,85 @@
package ox

import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import ox.util.{MaxCounter, Trail}
import ox.syntax.filterPar

import scala.List
import scala.collection.IterableFactory
import scala.collection.immutable.Iterable

class FilterParTest extends AnyFlatSpec with Matchers {
  "filterPar" should "output the same type as input" in {
    val input = List(1, 2, 3)
    val result = input.filterPar(1)(_ => true)
    result shouldBe a[List[_]]
  }

  it should "run computations in parallel" in {
    val InputElements = 17
    val TransformationMillis: Long = 100

    val input = (0 to InputElements)
    def predicate(i: Int) = {
      Thread.sleep(TransformationMillis)
      i % 2 == 0
    }

    val start = System.currentTimeMillis()
    val result = input.to(Iterable).filterPar(5)(predicate)
    val end = System.currentTimeMillis()

    result.toList should contain theSameElementsInOrderAs List(0, 2, 4, 6, 8, 10, 12, 14, 16)
    (end - start) should be < (InputElements * TransformationMillis)
  }

  it should "not run more computations than the limit" in {
    val Parallelism = 5

    val input = (1 to 158)

    val maxCounter = new MaxCounter()

    def predicate(i: Int) = {
      maxCounter.increment()
      Thread.sleep(10)
      maxCounter.decrement()
      true
    }

    input.to(Iterable).filterPar(Parallelism)(predicate)

    maxCounter.max should be <= Parallelism
  }

  it should "interrupt other computations if one fails" in {
    val InputElements = 18
    val TransformationMillis: Long = 100
    val trail = Trail()

    val input = (0 to InputElements)

    def predicate(i: Int) = {
      if (i == 4) {
        trail.add("exception")
        throw new Exception("boom")
      } else {
        Thread.sleep(TransformationMillis)
        trail.add("transformation")
        true
      }
    }

    try {
      input.to(Iterable).filterPar(5)(predicate)
    } catch {
      case e: Exception if e.getMessage == "boom" => trail.add("catch")
    }

    Thread.sleep(300)
    trail.add("all done")

    trail.get shouldBe Vector("exception", "catch", "all done")
  }
}