Skip to content

Commit d94cbe2

Browse files
Kamil-LontkowskiKamil-Lontkowskiadamw
authored
RateLimiter - consider whole operation execution time (#251)
Co-authored-by: Kamil-Lontkowski <[email protected]> Co-authored-by: adamw <[email protected]>
1 parent 13cc988 commit d94cbe2

9 files changed

+494
-159
lines changed

README.md

+8
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,14 @@ def computationR: Int = ???
103103
repeat(RepeatConfig.fixedRateForever(100.millis))(computationR)
104104
```
105105

106+
[Rate limit](https://ox.softwaremill.com/latest/utils/rate-limiter.html) computations:
107+
108+
```scala mdoc:compile-only
109+
supervised:
110+
val rateLimiter = RateLimiter.fixedWindowWithStartTime(2, 1.second)
111+
rateLimiter.runBlocking({ /* ... */ })
112+
```
113+
106114
Allocate a [resource](https://ox.softwaremill.com/latest/utils/resources.html) in a scope:
107115

108116
```scala mdoc:compile-only
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
package ox.resilience
2+
3+
import java.util.concurrent.Semaphore
4+
import java.util.concurrent.atomic.AtomicInteger
5+
import java.util.concurrent.atomic.AtomicLong
6+
import java.util.concurrent.atomic.AtomicReference
7+
import scala.annotation.tailrec
8+
import scala.collection.immutable.Queue
9+
import scala.concurrent.duration.FiniteDuration
10+
import ox.discard
11+
12+
/** Algorithms, which take into account the entire duration of the operation.
13+
*
14+
* There is no leakyBucket algorithm implemented, which is present in [[StartTimeRateLimiterAlgorithm]], because effectively it would
15+
* result in "max number of operations currently running", which can be achieved with single semaphore.
16+
*/
17+
object DurationRateLimiterAlgorithm:
18+
/** Fixed window algorithm: allows running at most `rate` operations in consecutively segments of duration `per`. Considers whole
19+
* execution time of an operation. Operation spanning more than one window blocks permits in all windows that it spans.
20+
*/
21+
case class FixedWindow(rate: Int, per: FiniteDuration) extends RateLimiterAlgorithm:
22+
private val lastUpdate = new AtomicLong(System.nanoTime())
23+
private val semaphore = new Semaphore(rate)
24+
private val runningOperations = new AtomicInteger(0)
25+
26+
def acquire(permits: Int): Unit =
27+
semaphore.acquire(permits)
28+
29+
def tryAcquire(permits: Int): Boolean =
30+
semaphore.tryAcquire(permits)
31+
32+
def getNextUpdate: Long =
33+
val waitTime = lastUpdate.get() + per.toNanos - System.nanoTime()
34+
if waitTime > 0 then waitTime else 0L
35+
36+
def update(): Unit =
37+
val now = System.nanoTime()
38+
lastUpdate.set(now)
39+
// We treat running operation in new window the same as a new operation that started in this window, so we replenish permits to: rate - operationsRunning
40+
semaphore.release(rate - semaphore.availablePermits() - runningOperations.get())
41+
end update
42+
43+
def runOperation[T](operation: => T, permits: Int): T =
44+
runningOperations.updateAndGet(_ + permits)
45+
try operation
46+
finally runningOperations.updateAndGet(_ - permits).discard
47+
48+
end FixedWindow
49+
50+
/** Sliding window algorithm: allows to run at most `rate` operations in the lapse of `per` before current time. Considers whole execution
51+
* time of an operation. Operation release permit after `per` passed since operation ended.
52+
*/
53+
case class SlidingWindow(rate: Int, per: FiniteDuration) extends RateLimiterAlgorithm:
54+
// stores the timestamp and the number of permits acquired after finishing running operation
55+
private val log = new AtomicReference[Queue[(Long, Int)]](Queue[(Long, Int)]())
56+
private val semaphore = new Semaphore(rate)
57+
58+
def acquire(permits: Int): Unit =
59+
semaphore.acquire(permits)
60+
61+
def tryAcquire(permits: Int): Boolean =
62+
semaphore.tryAcquire(permits)
63+
64+
private def addTimestampToLog(permits: Int): Unit =
65+
val now = System.nanoTime()
66+
log.updateAndGet { q =>
67+
q.enqueue((now, permits))
68+
}
69+
()
70+
71+
def getNextUpdate: Long =
72+
log.get().headOption match
73+
case None =>
74+
// no logs so no need to update until `per` has passed
75+
per.toNanos
76+
case Some(record) =>
77+
// oldest log provides the new updating point
78+
val waitTime = record._1 + per.toNanos - System.nanoTime()
79+
if waitTime > 0 then waitTime else 0L
80+
end getNextUpdate
81+
82+
def runOperation[T](operation: => T, permits: Int): T =
83+
try operation
84+
// Consider end of operation as a point to release permit after `per` passes
85+
finally addTimestampToLog(permits)
86+
87+
def update(): Unit =
88+
val now = System.nanoTime()
89+
// retrieving current queue to append it later if some elements were added concurrently
90+
val q = log.getAndUpdate(_ => Queue[(Long, Int)]())
91+
// remove records older than window size
92+
val qUpdated = removeRecords(q, now)
93+
// merge old records with the ones concurrently added
94+
log.updateAndGet(qNew =>
95+
qNew.foldLeft(qUpdated) { case (queue, record) =>
96+
queue.enqueue(record)
97+
}
98+
)
99+
()
100+
end update
101+
102+
@tailrec
103+
private def removeRecords(q: Queue[(Long, Int)], now: Long): Queue[(Long, Int)] =
104+
q.dequeueOption match
105+
case None => q
106+
case Some((head, tail)) =>
107+
if head._1 + per.toNanos < now then
108+
val (_, permits) = head
109+
semaphore.release(permits)
110+
removeRecords(tail, now)
111+
else q
112+
end match
113+
end removeRecords
114+
115+
end SlidingWindow
116+
117+
end DurationRateLimiterAlgorithm

core/src/main/scala/ox/resilience/RateLimiter.scala

+55-15
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,24 @@ package ox.resilience
22

33
import scala.concurrent.duration.FiniteDuration
44
import ox.*
5-
65
import scala.annotation.tailrec
76

8-
/** Rate limiter with a customizable algorithm. Operations can be blocked or dropped, when the rate limit is reached. */
7+
/** Rate limiter with a customizable algorithm. Operations can be blocked or dropped, when the rate limit is reached. The rate limiter might
8+
* take into account the start time of the operation, or its entire duration.
9+
*/
910
class RateLimiter private (algorithm: RateLimiterAlgorithm):
1011
/** Runs the operation, blocking if the rate limit is reached, until the rate limiter is replenished. */
1112
def runBlocking[T](operation: => T): T =
1213
algorithm.acquire()
13-
operation
14+
algorithm.runOperation(operation)
1415

15-
/** Runs or drops the operation, if the rate limit is reached.
16+
/** Runs the operation or drops it, if the rate limit is reached.
1617
*
1718
* @return
18-
* `Some` if the operation has been allowed to run, `None` if the operation has been dropped.
19+
* `Some` if the operation has been run, `None` if the operation has been dropped.
1920
*/
2021
def runOrDrop[T](operation: => T): Option[T] =
21-
if algorithm.tryAcquire() then Some(operation)
22+
if algorithm.tryAcquire() then Some(algorithm.runOperation(operation))
2223
else None
2324

2425
end RateLimiter
@@ -39,32 +40,36 @@ object RateLimiter:
3940
new RateLimiter(algorithm)
4041
end apply
4142

42-
/** Creates a rate limiter using a fixed window algorithm.
43+
/** Creates a rate limiter using a fixed window algorithm. Takes into account the start time of the operation only.
4344
*
4445
* Must be run within an [[Ox]] concurrency scope, as a background fork is created, to replenish the rate limiter.
4546
*
4647
* @param maxOperations
4748
* Maximum number of operations that are allowed to **start** within a time [[window]].
4849
* @param window
49-
* Interval of time between replenishing the rate limiter. THe rate limiter is replenished to allow up to [[maxOperations]] in the next
50+
* Interval of time between replenishing the rate limiter. The rate limiter is replenished to allow up to [[maxOperations]] in the next
5051
* time window.
52+
* @see
53+
* [[fixedWindowWithDuration]]
5154
*/
52-
def fixedWindow(maxOperations: Int, window: FiniteDuration)(using Ox): RateLimiter =
53-
apply(RateLimiterAlgorithm.FixedWindow(maxOperations, window))
55+
def fixedWindowWithStartTime(maxOperations: Int, window: FiniteDuration)(using Ox): RateLimiter =
56+
apply(StartTimeRateLimiterAlgorithm.FixedWindow(maxOperations, window))
5457

55-
/** Creates a rate limiter using a sliding window algorithm.
58+
/** Creates a rate limiter using a sliding window algorithm. Takes into account the start time of the operation only.
5659
*
5760
* Must be run within an [[Ox]] concurrency scope, as a background fork is created, to replenish the rate limiter.
5861
*
5962
* @param maxOperations
6063
* Maximum number of operations that are allowed to **start** within any [[window]] of time.
6164
* @param window
6265
* Length of the window.
66+
* @see
67+
* [[slidingWindowWithDuration]]
6368
*/
64-
def slidingWindow(maxOperations: Int, window: FiniteDuration)(using Ox): RateLimiter =
65-
apply(RateLimiterAlgorithm.SlidingWindow(maxOperations, window))
69+
def slidingWindowWithStartTime(maxOperations: Int, window: FiniteDuration)(using Ox): RateLimiter =
70+
apply(StartTimeRateLimiterAlgorithm.SlidingWindow(maxOperations, window))
6671

67-
/** Rate limiter with token/leaky bucket algorithm.
72+
/** Creates a rate limiter with token/leaky bucket algorithm. Takes into account the start time of the operation only.
6873
*
6974
* Must be run within an [[Ox]] concurrency scope, as a background fork is created, to replenish the rate limiter.
7075
*
@@ -74,5 +79,40 @@ object RateLimiter:
7479
* Interval of time between adding a single token to the bucket.
7580
*/
7681
def leakyBucket(maxTokens: Int, refillInterval: FiniteDuration)(using Ox): RateLimiter =
77-
apply(RateLimiterAlgorithm.LeakyBucket(maxTokens, refillInterval))
82+
apply(StartTimeRateLimiterAlgorithm.LeakyBucket(maxTokens, refillInterval))
83+
84+
/** Creates a rate limiter with a fixed window algorithm.
85+
*
86+
* Takes into account the entire duration of the operation. That is the instant at which the operation "happens" can be anywhere between
87+
* its start and end. This ensures that the rate limit is always respected, although it might make it more restrictive.
88+
*
89+
* Must be run within an [[Ox]] concurrency scope, as a background fork is created, to replenish the rate limiter.
90+
*
91+
* @param maxOperations
92+
* Maximum number of operations that are allowed to **run** (finishing from previous windows or start new) within a time [[window]].
93+
* @param window
94+
* Length of the window.
95+
* @see
96+
* [[fixedWindowWithStartTime]]
97+
*/
98+
def fixedWindowWithDuration(maxOperations: Int, window: FiniteDuration)(using Ox): RateLimiter =
99+
apply(DurationRateLimiterAlgorithm.FixedWindow(maxOperations, window))
100+
101+
/** Creates a rate limiter using a sliding window algorithm.
102+
*
103+
* Takes into account the entire duration of the operation. That is the instant at which the operation "happens" can be anywhere between
104+
* its start and end. This ensures that the rate limit is always respected, although it might make it more restrictive.
105+
*
106+
* Must be run within an [[Ox]] concurrency scope, as a background fork is created, to replenish the rate limiter.
107+
*
108+
* @param maxOperations
109+
* Maximum number of operations that are allowed to **run** (start or finishing) within any [[window]] of time.
110+
* @param window
111+
* Length of the window.
112+
* @see
113+
* [[slidingWindowWithStartTime]]
114+
*/
115+
def slidingWindowWithDuration(maxOperations: Int, window: FiniteDuration)(using Ox): RateLimiter =
116+
apply(DurationRateLimiterAlgorithm.SlidingWindow(maxOperations, window))
117+
78118
end RateLimiter
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,5 @@
11
package ox.resilience
22

3-
import scala.concurrent.duration.FiniteDuration
4-
import scala.collection.immutable.Queue
5-
import java.util.concurrent.atomic.AtomicLong
6-
import java.util.concurrent.atomic.AtomicReference
7-
import java.util.concurrent.Semaphore
8-
import scala.annotation.tailrec
9-
103
/** Determines the algorithm to use for the rate limiter */
114
trait RateLimiterAlgorithm:
125

@@ -30,113 +23,10 @@ trait RateLimiterAlgorithm:
3023
/** Returns the time in nanoseconds that needs to elapse until the next update. It should not modify internal state. */
3124
def getNextUpdate: Long
3225

33-
end RateLimiterAlgorithm
34-
35-
object RateLimiterAlgorithm:
36-
/** Fixed window algorithm: allows starting at most `rate` operations in consecutively segments of duration `per`. */
37-
case class FixedWindow(rate: Int, per: FiniteDuration) extends RateLimiterAlgorithm:
38-
private val lastUpdate = new AtomicLong(System.nanoTime())
39-
private val semaphore = new Semaphore(rate)
40-
41-
def acquire(permits: Int): Unit =
42-
semaphore.acquire(permits)
43-
44-
def tryAcquire(permits: Int): Boolean =
45-
semaphore.tryAcquire(permits)
46-
47-
def getNextUpdate: Long =
48-
val waitTime = lastUpdate.get() + per.toNanos - System.nanoTime()
49-
if waitTime > 0 then waitTime else 0L
50-
51-
def update(): Unit =
52-
val now = System.nanoTime()
53-
lastUpdate.set(now)
54-
semaphore.release(rate - semaphore.availablePermits())
55-
end update
56-
57-
end FixedWindow
58-
59-
/** Sliding window algorithm: allows to start at most `rate` operations in the lapse of `per` before current time. */
60-
case class SlidingWindow(rate: Int, per: FiniteDuration) extends RateLimiterAlgorithm:
61-
// stores the timestamp and the number of permits acquired after calling acquire or tryAcquire successfully
62-
private val log = new AtomicReference[Queue[(Long, Int)]](Queue[(Long, Int)]())
63-
private val semaphore = new Semaphore(rate)
64-
65-
def acquire(permits: Int): Unit =
66-
semaphore.acquire(permits)
67-
addTimestampToLog(permits)
68-
69-
def tryAcquire(permits: Int): Boolean =
70-
if semaphore.tryAcquire(permits) then
71-
addTimestampToLog(permits)
72-
true
73-
else false
74-
75-
private def addTimestampToLog(permits: Int): Unit =
76-
val now = System.nanoTime()
77-
log.updateAndGet { q =>
78-
q.enqueue((now, permits))
79-
}
80-
()
81-
82-
def getNextUpdate: Long =
83-
log.get().headOption match
84-
case None =>
85-
// no logs so no need to update until `per` has passed
86-
per.toNanos
87-
case Some(record) =>
88-
// oldest log provides the new updating point
89-
val waitTime = record._1 + per.toNanos - System.nanoTime()
90-
if waitTime > 0 then waitTime else 0L
91-
end getNextUpdate
92-
93-
def update(): Unit =
94-
val now = System.nanoTime()
95-
// retrieving current queue to append it later if some elements were added concurrently
96-
val q = log.getAndUpdate(_ => Queue[(Long, Int)]())
97-
// remove records older than window size
98-
val qUpdated = removeRecords(q, now)
99-
// merge old records with the ones concurrently added
100-
val _ = log.updateAndGet(qNew =>
101-
qNew.foldLeft(qUpdated) { case (queue, record) =>
102-
queue.enqueue(record)
103-
}
104-
)
105-
end update
106-
107-
@tailrec
108-
private def removeRecords(q: Queue[(Long, Int)], now: Long): Queue[(Long, Int)] =
109-
q.dequeueOption match
110-
case None => q
111-
case Some((head, tail)) =>
112-
if head._1 + per.toNanos < now then
113-
val (_, permits) = head
114-
semaphore.release(permits)
115-
removeRecords(tail, now)
116-
else q
117-
118-
end SlidingWindow
119-
120-
/** Token/leaky bucket algorithm It adds a token to start an new operation each `per` with a maximum number of tokens of `rate`. */
121-
case class LeakyBucket(rate: Int, per: FiniteDuration) extends RateLimiterAlgorithm:
122-
private val refillInterval = per.toNanos
123-
private val lastRefillTime = new AtomicLong(System.nanoTime())
124-
private val semaphore = new Semaphore(1)
125-
126-
def acquire(permits: Int): Unit =
127-
semaphore.acquire(permits)
128-
129-
def tryAcquire(permits: Int): Boolean =
130-
semaphore.tryAcquire(permits)
131-
132-
def getNextUpdate: Long =
133-
val waitTime = lastRefillTime.get() + refillInterval - System.nanoTime()
134-
if waitTime > 0 then waitTime else 0L
26+
/** Runs the operation, allowing the algorithm to take into account its duration, if needed. */
27+
final def runOperation[T](operation: => T): T = runOperation(operation, 1)
13528

136-
def update(): Unit =
137-
val now = System.nanoTime()
138-
lastRefillTime.set(now)
139-
if semaphore.availablePermits() < rate then semaphore.release()
29+
/** Runs the operation, allowing the algorithm to take into account its duration, if needed. */
30+
def runOperation[T](operation: => T, permits: Int): T
14031

141-
end LeakyBucket
14232
end RateLimiterAlgorithm

0 commit comments

Comments
 (0)