Skip to content

Commit b0bb1bc

Browse files
authored
Add Source.mapConcat operator (#99)
1 parent 1844aa4 commit b0bb1bc

File tree

2 files changed

+98
-1
lines changed

2 files changed

+98
-1
lines changed

core/src/main/scala/ox/channels/SourceOps.scala

+47-1
Original file line numberDiff line numberDiff line change
@@ -538,6 +538,52 @@ trait SourceOps[+T] { outer: Source[T] =>
538538
}
539539
c
540540

541+
/** Applies the given mapping function `f`, to each element received from this source, transforming it into an Iterable of results, then
542+
* sends the results one by one to the returned channel. Can be used to unfold incoming sequences of elements into single elements.
543+
*
544+
* @param f
545+
* A function that transforms the element from this source into a pair of the next state into an [[scala.collection.IterableOnce]] of
546+
* results which are sent one by one to the returned channel. If the result of `f` is empty, nothing is sent to the returned channel.
547+
* @return
548+
* A source to which the results of applying `f` to the elements from this source would be sent.
549+
* @example
550+
* {{{
551+
* scala>
552+
* import ox.*
553+
* import ox.channels.Source
554+
*
555+
* supervised {
556+
* val s = Source.fromValues(List(1, 2, 3), List(4, 5, 6), List(7, 8, 9))
557+
* s.mapConcat(identity)
558+
* }
559+
*
560+
* scala> val res0: List[Int] = List(1, 2, 3, 4, 5, 6, 7, 8, 9)
561+
* }}}
562+
*/
563+
def mapConcat[U](f: T => IterableOnce[U])(using Ox, StageCapacity): Source[U] =
564+
val c = StageCapacity.newChannel[U]
565+
fork {
566+
repeatWhile {
567+
receiveSafe() match
568+
case ChannelClosed.Done =>
569+
c.doneSafe()
570+
false
571+
case ChannelClosed.Error(r) =>
572+
c.errorSafe(r)
573+
false
574+
case t: T @unchecked =>
575+
try
576+
val results: IterableOnce[U] = f(t)
577+
results.iterator.foreach(c.send)
578+
true
579+
catch
580+
case t: Throwable =>
581+
c.errorSafe(t)
582+
false
583+
}
584+
}
585+
c
586+
541587
/** Returns the first element from this source wrapped in [[Some]] or [[None]] when this source is empty. Note that `headOption` is not an
542588
* idempotent operation on source as it receives elements from it.
543589
*
@@ -565,7 +611,7 @@ trait SourceOps[+T] { outer: Source[T] =>
565611
case e: ChannelClosed.Error => throw e.toThrowable
566612
case t: T @unchecked => Some(t)
567613
}
568-
614+
569615
/** Sends elements to the returned channel limiting the throughput to specific number of elements (evenly spaced) per time unit. Note that
570616
* the element's `receive()` time is included in the resulting throughput. For instance having `throttle(1, 1.second)` and `receive()`
571617
* taking `Xms` means that resulting channel will receive elements every `1s + Xms` time. Throttling is not applied to the empty source.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
package ox.channels
2+
3+
import org.scalatest.flatspec.AnyFlatSpec
4+
import org.scalatest.matchers.should.Matchers
5+
import ox.*
6+
7+
class SourceOpsMapConcatTest extends AnyFlatSpec with Matchers {
8+
9+
behavior of "Source.mapConcat"
10+
11+
it should "unfold iterables" in supervised {
12+
val c = Source.fromValues(List("a", "b", "c"), List("d", "e"), List("f"))
13+
val s = c.mapConcat(identity)
14+
s.toList shouldBe List("a", "b", "c", "d", "e", "f")
15+
}
16+
17+
it should "transform elements" in supervised {
18+
val c = Source.fromValues("ab", "cd")
19+
val s = c.mapConcat { str => str.toList }
20+
21+
s.toList shouldBe List('a', 'b', 'c', 'd')
22+
}
23+
24+
it should "handle empty lists" in supervised {
25+
val c = Source.fromValues(List.empty, List("a"), List.empty, List("b", "c"))
26+
val s = c.mapConcat(identity)
27+
28+
s.toList shouldBe List("a", "b", "c")
29+
}
30+
31+
it should "propagate errors in the mapping function" in supervised {
32+
// given
33+
given StageCapacity = StageCapacity(0) // so that the error isn't created too early
34+
val c = Source.fromValues(List("a"), List("b", "c"), List("error here"))
35+
36+
// when
37+
val s = c.mapConcat { element =>
38+
if (element != List("error here"))
39+
element
40+
else throw new RuntimeException("boom")
41+
}
42+
43+
// then
44+
s.receive() shouldBe "a"
45+
s.receive() shouldBe "b"
46+
s.receive() shouldBe "c"
47+
s.receiveSafe() should matchPattern {
48+
case ChannelClosed.Error(reason) if reason.getMessage == "boom" =>
49+
}
50+
}
51+
}

0 commit comments

Comments
 (0)