Skip to content

Commit 02f5198

Browse files
authored
Merge pull request #15 from softwaremill/source-stateful-map
Add Source.mapStateful and Source.mapStatefulConcat combinators
2 parents 12e80dd + 3db1ded commit 02f5198

File tree

3 files changed

+247
-0
lines changed

3 files changed

+247
-0
lines changed

core/src/main/scala/ox/channels/SourceOps.scala

+108
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import ox.*
55
import java.util.concurrent.{CountDownLatch, Semaphore}
66
import scala.collection.mutable
77
import scala.concurrent.duration.FiniteDuration
8+
import scala.collection.IterableOnce
89

910
trait SourceOps[+T] { this: Source[T] =>
1011
// view ops (lazy)
@@ -311,6 +312,113 @@ trait SourceOps[+T] { this: Source[T] =>
311312
def drain(): Unit = foreach(_ => ())
312313

313314
def applied[U](f: Source[T] => U): U = f(this)
315+
316+
/** Applies the given mapping function `f`, using additional state, to each element received from this source, and sends the results to
317+
* the returned channel. Optionally sends an additional element, possibly based on the final state, to the returned channel once this
318+
* source is done.
319+
*
320+
* The `initializeState` function is called once when `statefulMap` is called.
321+
*
322+
* The `onComplete` function is called once when this source is done. If it returns a non-empty value, the value will be sent to the
323+
* returned channel, while an empty value will be ignored.
324+
*
325+
* @param initializeState
326+
* A function that initializes the state.
327+
* @param f
328+
* A function that transforms the element from this source and the state into a pair of the next state and the result which is sent
329+
* sent to the returned channel.
330+
* @param onComplete
331+
* A function that transforms the final state into an optional element sent to the returned channel. By default the final state is
332+
* ignored.
333+
* @return
334+
* A source to which the results of applying `f` to the elements from this source would be sent.
335+
* @example
336+
* {{{
337+
* scala>
338+
* import ox.*
339+
* import ox.channels.Source
340+
*
341+
* scoped {
342+
* val s = Source.fromValues(1, 2, 3, 4, 5)
343+
* s.mapStateful(() => 0)((sum, element) => (sum + element, sum), Some.apply)
344+
* }
345+
*
346+
* scala> val res0: List[Int] = List(0, 1, 3, 6, 10, 15)
347+
* }}}
348+
*/
349+
def mapStateful[S, U >: T](
350+
initializeState: () => S
351+
)(f: (S, T) => (S, U), onComplete: S => Option[U] = (_: S) => None)(using Ox, StageCapacity): Source[U] =
352+
def resultToSome(s: S, t: T) =
353+
val (newState, result) = f(s, t)
354+
(newState, Some(result))
355+
356+
mapStatefulConcat(initializeState)(resultToSome, onComplete)
357+
358+
/** Applies the given mapping function `f`, using additional state, to each element received from this source, and sends the results one
359+
* by one to the returned channel. Optionally sends an additional element, possibly based on the final state, to the returned channel
360+
* once this source is done.
361+
*
362+
* The `initializeState` function is called once when `statefulMap` is called.
363+
*
364+
* The `onComplete` function is called once when this source is done. If it returns a non-empty value, the value will be sent to the
365+
* returned channel, while an empty value will be ignored.
366+
*
367+
* @param initializeState
368+
* A function that initializes the state.
369+
* @param f
370+
* A function that transforms the element from this source and the state into a pair of the next state and a
371+
* [[scala.collection.IterableOnce]] of results which are sent one by one to the returned channel. If the result of `f` is empty,
372+
* nothing is sent to the returned channel.
373+
* @param onComplete
374+
* A function that transforms the final state into an optional element sent to the returned channel. By default the final state is
375+
* ignored.
376+
* @return
377+
* A source to which the results of applying `f` to the elements from this source would be sent.
378+
* @example
379+
* {{{
380+
* scala>
381+
* import ox.*
382+
* import ox.channels.Source
383+
*
384+
* scoped {
385+
* val s = Source.fromValues(1, 2, 2, 3, 2, 4, 3, 1, 5)
386+
* // deduplicate the values
387+
* s.mapStatefulConcat(() => Set.empty[Int])((s, e) => (s + e, Option.unless(s.contains(e))(e)))
388+
* }
389+
*
390+
* scala> val res0: List[Int] = List(1, 2, 3, 4, 5)
391+
* }}}
392+
*/
393+
def mapStatefulConcat[S, U >: T](
394+
initializeState: () => S
395+
)(f: (S, T) => (S, IterableOnce[U]), onComplete: S => Option[U] = (_: S) => None)(using Ox, StageCapacity): Source[U] =
396+
val c = StageCapacity.newChannel[U]
397+
forkDaemon {
398+
var state = initializeState()
399+
repeatWhile {
400+
receive() match
401+
case ChannelClosed.Done =>
402+
try
403+
onComplete(state).foreach(c.send)
404+
c.done()
405+
catch case t: Throwable => c.error(t)
406+
false
407+
case ChannelClosed.Error(r) =>
408+
c.error(r)
409+
false
410+
case t: T @unchecked =>
411+
try
412+
val (nextState, result) = f(state, t)
413+
state = nextState
414+
result.iterator.map(c.send).forall(_.isValue)
415+
catch
416+
case t: Throwable =>
417+
c.error(t)
418+
false
419+
}
420+
}
421+
c
314422
}
315423

316424
trait SourceCompanionOps:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package ox.channels
2+
3+
import org.scalatest.flatspec.AnyFlatSpec
4+
import org.scalatest.matchers.should.Matchers
5+
import ox.*
6+
7+
class SourceOpsMapStatefulConcatTest extends AnyFlatSpec with Matchers {
8+
9+
behavior of "Source.mapStatefulConcat"
10+
11+
it should "deduplicate" in scoped {
12+
// given
13+
val c = Source.fromValues(1, 2, 2, 3, 2, 4, 3, 1, 5)
14+
15+
// when
16+
val s = c.mapStatefulConcat(() => Set.empty[Int])((s, e) => (s + e, Option.unless(s.contains(e))(e)))
17+
18+
// then
19+
s.toList shouldBe List(1, 2, 3, 4, 5)
20+
}
21+
22+
it should "count consecutive" in scoped {
23+
// given
24+
val c = Source.fromValues("apple", "apple", "apple", "banana", "orange", "orange", "apple")
25+
26+
// when
27+
val s = c.mapStatefulConcat(() => (Option.empty[String], 0))(
28+
{ case ((previous, count), e) =>
29+
previous match
30+
case None => ((Some(e), 1), None)
31+
case Some(`e`) => ((previous, count + 1), None)
32+
case Some(_) => ((Some(e), 1), previous.map((_, count)))
33+
},
34+
{ case (previous, count) => previous.map((_, count)) }
35+
)
36+
37+
// then
38+
s.toList shouldBe List(
39+
("apple", 3),
40+
("banana", 1),
41+
("orange", 2),
42+
("apple", 1)
43+
)
44+
}
45+
46+
it should "propagate errors in the mapping function" in scoped {
47+
// given
48+
val c = Source.fromValues("a", "b", "c")
49+
50+
// when
51+
val s = c.mapStatefulConcat(() => 0) { (index, element) =>
52+
if (index < 2) (index + 1, Some(element))
53+
else throw new RuntimeException("boom")
54+
}
55+
56+
// then
57+
s.receive() shouldBe "a"
58+
s.receive() shouldBe "b"
59+
s.receive() should matchPattern {
60+
case ChannelClosed.Error(Some(reason)) if reason.getMessage == "boom" =>
61+
}
62+
}
63+
64+
it should "propagate errors in the completion callback" in scoped {
65+
// given
66+
val c = Source.fromValues("a", "b", "c")
67+
68+
// when
69+
val s = c.mapStatefulConcat(() => 0)((index, element) => (index + 1, Some(element)), _ => throw new RuntimeException("boom"))
70+
71+
// then
72+
s.receive() shouldBe "a"
73+
s.receive() shouldBe "b"
74+
s.receive() shouldBe "c"
75+
s.receive() should matchPattern {
76+
case ChannelClosed.Error(Some(reason)) if reason.getMessage == "boom" =>
77+
}
78+
}
79+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
package ox.channels
2+
3+
import org.scalatest.flatspec.AnyFlatSpec
4+
import org.scalatest.matchers.should.Matchers
5+
import ox.*
6+
7+
class SourceOpsMapStatefulTest extends AnyFlatSpec with Matchers {
8+
9+
behavior of "Source.mapStateful"
10+
11+
it should "zip with index" in scoped {
12+
val c = Source.fromValues("a", "b", "c")
13+
14+
val s = c.mapStateful(() => 0)((index, element) => (index + 1, (element, index)))
15+
16+
s.toList shouldBe List(("a", 0), ("b", 1), ("c", 2))
17+
}
18+
19+
it should "calculate a running total" in scoped {
20+
val c = Source.fromValues(1, 2, 3, 4, 5)
21+
22+
val s = c.mapStateful(() => 0)((sum, element) => (sum + element, sum), Some.apply)
23+
24+
s.toList shouldBe List(0, 1, 3, 6, 10, 15)
25+
}
26+
27+
it should "propagate errors in the mapping function" in scoped {
28+
// given
29+
val c = Source.fromValues("a", "b", "c")
30+
31+
// when
32+
val s = c.mapStateful(() => 0) { (index, element) =>
33+
if (index < 2) (index + 1, element)
34+
else throw new RuntimeException("boom")
35+
}
36+
37+
// then
38+
s.receive() shouldBe "a"
39+
s.receive() shouldBe "b"
40+
s.receive() should matchPattern {
41+
case ChannelClosed.Error(Some(reason)) if reason.getMessage == "boom" =>
42+
}
43+
}
44+
45+
it should "propagate errors in the completion callback" in scoped {
46+
// given
47+
val c = Source.fromValues("a", "b", "c")
48+
49+
// when
50+
val s = c.mapStateful(() => 0)((index, element) => (index + 1, element), _ => throw new RuntimeException("boom"))
51+
52+
// then
53+
s.receive() shouldBe "a"
54+
s.receive() shouldBe "b"
55+
s.receive() shouldBe "c"
56+
s.receive() should matchPattern {
57+
case ChannelClosed.Error(Some(reason)) if reason.getMessage == "boom" =>
58+
}
59+
}
60+
}

0 commit comments

Comments
 (0)