Skip to content

Commit 13f8b50

Browse files
feat: implement throttle function
Sends elements to the returned channel limiting the throughput to specific number of elements (evenly spaced) per time unit. Note that the element's `receive()` time is included in the resulting throughput. For instance having `throttle(1, 1.second)` and `receive()` taking `Xms` means that resulting channel will receive elements every `1s + Xms` time. Throttling is not applied to the empty source. Examples: Source.empty[Int].throttle(1, 1.second).toList // List() returned without throttling Source.fromValues(1, 2).throttle(1, 1.second).toList // List(1, 2) returned after 2 seconds Note that implementation relies on `Thread.sleep` that is according to [1] project Loom compatible. [1] https://softwaremill.com/what-is-blocking-in-loom/
1 parent e28a268 commit 13f8b50

File tree

2 files changed

+85
-0
lines changed

2 files changed

+85
-0
lines changed

core/src/main/scala/ox/channels/SourceOps.scala

+40
Original file line numberDiff line numberDiff line change
@@ -565,6 +565,46 @@ trait SourceOps[+T] { this: Source[T] =>
565565
case ChannelClosed.Error(r) => throw r.getOrElse(new NoSuchElementException("getting head failed"))
566566
case t: T @unchecked => t
567567
}
568+
569+
/** Sends elements to the returned channel limiting the throughput to specific number of elements (evenly spaced) per time unit. Note that
570+
* the element's `receive()` time is included in the resulting throughput. For instance having `throttle(1, 1.second)` and `receive()`
571+
* taking `Xms` means that resulting channel will receive elements every `1s + Xms` time. Throttling is not applied to the empty source.
572+
*
573+
* @param elements
574+
* Number of elements to be emitted. Must be greater than 0.
575+
* @param per
576+
* Per time unit. Must be greater or equal to 1 ms.
577+
* @return
578+
* A source that emits at most `elements` `per` time unit.
579+
* @example
580+
* {{{
581+
* import ox.*
582+
* import ox.channels.Source
583+
*
584+
* import scala.concurrent.duration.*
585+
*
586+
* scoped {
587+
* Source.empty[Int].throttle(1, 1.second).toList // List() returned without throttling
588+
* Source.fromValues(1, 2).throttle(1, 1.second).toList // List(1, 2) returned after 2 seconds
589+
* }
590+
* }}}
591+
*/
592+
def throttle(elements: Int, per: FiniteDuration)(using Ox, StageCapacity): Source[T] =
593+
require(elements > 0, "elements must be > 0")
594+
require(per.toMillis > 0, "per time must be >= 1 ms")
595+
596+
val c = StageCapacity.newChannel[T]
597+
val emitEveryMillis = per.toMillis / elements
598+
599+
forkDaemon {
600+
repeatWhile {
601+
receive() match
602+
case ChannelClosed.Done => c.done(); false
603+
case ChannelClosed.Error(r) => c.error(r); false
604+
case t: T @unchecked => Thread.sleep(emitEveryMillis); c.send(t); true
605+
}
606+
}
607+
c
568608
}
569609

570610
trait SourceCompanionOps:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package ox.channels
2+
3+
import org.scalatest.flatspec.AnyFlatSpec
4+
import org.scalatest.matchers.should.Matchers
5+
import ox.*
6+
7+
import scala.concurrent.duration.*
8+
9+
class SourceOpsThrottleTest extends AnyFlatSpec with Matchers {
10+
behavior of "Source.throttle"
11+
12+
it should "not throttle the empty source" in supervised {
13+
val s = Source.empty[Int]
14+
val (result, executionTime) = measure { s.throttle(1, 1.second).toList }
15+
result shouldBe List.empty
16+
executionTime.toMillis should be < 1.second.toMillis
17+
}
18+
19+
it should "throttle to specified elements per time units" in supervised {
20+
val s = Source.fromValues(1, 2)
21+
val (result, executionTime) = measure { s.throttle(1, 50.millis).toList }
22+
result shouldBe List(1, 2)
23+
executionTime.toMillis should (be >= 100L and be <= 150L)
24+
}
25+
26+
it should "fail to throttle when elements <= 0" in supervised {
27+
val s = Source.empty[Int]
28+
the[IllegalArgumentException] thrownBy {
29+
s.throttle(-1, 50.millis)
30+
} should have message "requirement failed: elements must be > 0"
31+
}
32+
33+
it should "fail to throttle when per lower than 1ms" in supervised {
34+
val s = Source.empty[Int]
35+
the[IllegalArgumentException] thrownBy {
36+
s.throttle(1, 50.nanos)
37+
} should have message "requirement failed: per time must be >= 1 ms"
38+
}
39+
40+
private def measure[T](f: => T): (T, Duration) =
41+
val before = System.currentTimeMillis()
42+
val result = f
43+
val after = System.currentTimeMillis();
44+
(result, (after - before).millis)
45+
}

0 commit comments

Comments
 (0)