Skip to content

Commit 775f8db

Browse files
feat: implement fold operator
The `fold` operation returns combined value retrieved from running function `f` on all source elements in a cumulative manner where result of the previous call is used as an input value to the next e.g.: Source.empty[Int].fold(0)((acc, n) => acc + n) // 0 Source.fromValues(2, 3).fold(5)((acc, n) => acc - n) // 0 Note that in case when `receive()` operation fails then ChannelClosedException.Error exception is thrown. Wheres in case when function `f` throws then this exception is propagated up to the caller.
1 parent 4c456aa commit 775f8db

File tree

2 files changed

+83
-0
lines changed

2 files changed

+83
-0
lines changed

core/src/main/scala/ox/channels/SourceOps.scala

+36
Original file line numberDiff line numberDiff line change
@@ -663,6 +663,42 @@ trait SourceOps[+T] { this: Source[T] =>
663663
* }}}
664664
*/
665665
def last(): T = lastOption().getOrElse(throw new NoSuchElementException("cannot obtain last element from an empty source"))
666+
667+
/** Uses `zero` as the current value and applies function `f` on it and a value received from this source. The returned value is used as
668+
* the next current value and `f` is applied again with the value received from a source. The operation is repeated until the source is
669+
* drained.
670+
*
671+
* @param zero
672+
* An initial value to be used as the first argument to function `f` call.
673+
* @param f
674+
* A binary function (a function that takes two arguments) that is applied to the current value and value received from a source.
675+
* @return
676+
* Combined value retrieved from running function `f` on all source elements in a cumulative manner where result of the previous call
677+
* is used as an input value to the next.
678+
* @throws ChannelClosedException.Error
679+
* When receiving an element from this source fails.
680+
* @throws exception
681+
* When function `f` throws an `exception` then it is propagated up to the caller.
682+
* @example
683+
* {{{
684+
* import ox.*
685+
* import ox.channels.Source
686+
*
687+
* supervised {
688+
* Source.empty[Int].fold(0)((acc, n) => acc + n) // 0
689+
* Source.fromValues(2, 3).fold(5)((acc, n) => acc - n) // 0
690+
* }
691+
* }}}
692+
*/
693+
def fold[U](zero: U)(f: (U, T) => U): U =
694+
var current = zero
695+
repeatWhile {
696+
receive() match
697+
case ChannelClosed.Done => false
698+
case e: ChannelClosed.Error => throw e.toThrowable
699+
case t: T @unchecked => current = f(current, t); true
700+
}
701+
current
666702
}
667703

668704
trait SourceCompanionOps:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package ox.channels
2+
3+
import org.scalatest.flatspec.AnyFlatSpec
4+
import org.scalatest.matchers.should.Matchers
5+
import ox.*
6+
7+
class SourceOpsFoldTest extends AnyFlatSpec with Matchers {
8+
behavior of "Source.fold"
9+
10+
it should "throw ChannelClosedException.Error with exception and message that was thrown during retrieval" in supervised {
11+
the[ChannelClosedException.Error] thrownBy {
12+
Source
13+
.failed[Int](new RuntimeException("source is broken"))
14+
.fold(0)((acc, n) => acc + n)
15+
} should have message "java.lang.RuntimeException: source is broken"
16+
}
17+
18+
it should "throw ChannelClosedException.Error for source failed without exception" in supervised {
19+
the[ChannelClosedException.Error] thrownBy {
20+
Source
21+
.failedWithoutReason[Int]()
22+
.fold(0)((acc, n) => acc + n)
23+
}
24+
}
25+
26+
it should "throw exception thrown in `f` when `f` throws" in supervised {
27+
the[RuntimeException] thrownBy {
28+
Source
29+
.fromValues(1)
30+
.fold(0)((_, _) => throw new RuntimeException("Function `f` is broken"))
31+
} should have message "Function `f` is broken"
32+
}
33+
34+
it should "return `zero` value from fold on the empty source" in supervised {
35+
Source.empty[Int].fold(0)((acc, n) => acc + n) shouldBe 0
36+
}
37+
38+
it should "return fold on non-empty source" in supervised {
39+
Source.fromValues(1, 2).fold(0)((acc, n) => acc + n) shouldBe 3
40+
}
41+
42+
it should "drain the source" in supervised {
43+
val s = Source.fromValues(1)
44+
s.fold(0)((acc, n) => acc + n) shouldBe 1
45+
s.receive() shouldBe ChannelClosed.Done
46+
}
47+
}

0 commit comments

Comments
 (0)