Skip to content

Commit bd6cecd

Browse files
authored
Add Source.repeatEval, improve exception handling in sources (#96)
1 parent 7d7e0ff commit bd6cecd

File tree

4 files changed

+56
-21
lines changed

4 files changed

+56
-21
lines changed

core/src/main/scala/ox/channels/SourceCompanionOps.scala

+23-4
Original file line numberDiff line numberDiff line change
@@ -125,12 +125,31 @@ trait SourceCompanionOps:
125125
}
126126
c
127127

128-
def repeat[T](element: T = ())(using Ox, StageCapacity): Source[T] =
128+
/** Creates a channel, to which the given `element` is sent repeatedly.
129+
*
130+
* @param element
131+
* The element to send
132+
* @return
133+
* A source to which the given element is sent repeatedly.
134+
*/
135+
def repeat[T](element: T = ())(using Ox, StageCapacity): Source[T] = repeatEval(element)
136+
137+
/** Creates a channel, to which the result of evaluating `f` is sent repeatedly. As the parameter is passed by-name, the evaluation is
138+
* deferred until the element is sent, and happens multiple times.
139+
*
140+
* @param f
141+
* The code block, computing the element to send
142+
* @return
143+
* A source to which the result of evaluating `f` is sent repeatedly.
144+
*/
145+
def repeatEval[T](f: => T)(using Ox, StageCapacity): Source[T] =
129146
val c = StageCapacity.newChannel[T]
130147
fork {
131-
forever {
132-
c.sendSafe(element)
133-
}
148+
try
149+
forever {
150+
c.sendSafe(f)
151+
}
152+
catch case t: Throwable => c.errorSafe(t)
134153
}
135154
c
136155

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
package ox.channels
2+
3+
import org.scalatest.flatspec.AnyFlatSpec
4+
import org.scalatest.matchers.should.Matchers
5+
import ox.supervised
6+
import ox.channels.Source
7+
8+
class SourceOpsRepeatEvalTest extends AnyFlatSpec with Matchers {
9+
behavior of "SourceOps.repeatEval"
10+
11+
it should "evaluate the element before each send" in supervised {
12+
var i = 0
13+
val s = Source.repeatEval {
14+
i += 1
15+
i
16+
}
17+
s.take(3).toList shouldBe List(1, 2, 3)
18+
}
19+
}

kafka/src/main/scala/ox/kafka/KafkaConsumerActor.scala

+11-12
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import ox.*
77
import ox.channels.*
88

99
import scala.jdk.CollectionConverters.*
10-
import scala.util.control.NonFatal
1110

1211
object KafkaConsumerActor:
1312
private val logger = LoggerFactory.getLogger(classOf[KafkaConsumerActor.type])
@@ -30,30 +29,30 @@ object KafkaConsumerActor:
3029
consumer.subscribe(topics.asJava)
3130
true
3231
catch
33-
case NonFatal(e) =>
34-
logger.error(s"Exception when subscribing to $topics", e)
35-
c.errorSafe(e)
32+
case t: Throwable =>
33+
logger.error(s"Exception when subscribing to $topics", t)
34+
c.errorSafe(t)
3635
false
3736
case KafkaConsumerRequest.Poll(results) =>
3837
try
3938
results.send(consumer.poll(java.time.Duration.ofMillis(100)))
4039
true
4140
catch
42-
case NonFatal(e) =>
43-
logger.error("Exception when polling for records in Kafka", e)
44-
results.errorSafe(e)
45-
c.errorSafe(e)
41+
case t: Throwable =>
42+
logger.error("Exception when polling for records in Kafka", t)
43+
results.errorSafe(t)
44+
c.errorSafe(t)
4645
false
4746
case KafkaConsumerRequest.Commit(offsets, result) =>
4847
try
4948
consumer.commitSync(offsets.view.mapValues(o => new OffsetAndMetadata(o + 1)).toMap.asJava)
5049
result.sendSafe(())
5150
true
5251
catch
53-
case NonFatal(e) =>
54-
logger.error("Exception when committing offsets", e)
55-
result.errorSafe(e)
56-
c.errorSafe(e)
52+
case t: Throwable =>
53+
logger.error("Exception when committing offsets", t)
54+
result.errorSafe(t)
55+
c.errorSafe(t)
5756
false
5857
}
5958
finally

kafka/src/main/scala/ox/kafka/KafkaSource.scala

+3-5
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@ import org.slf4j.LoggerFactory
55
import ox.*
66
import ox.channels.*
77

8-
import scala.util.control.NonFatal
9-
108
object KafkaSource:
119
private val logger = LoggerFactory.getLogger(classOf[KafkaSource.type])
1210

@@ -38,9 +36,9 @@ object KafkaSource:
3836
records.forEach(r => c.send(ReceivedMessage(kafkaConsumer, r)))
3937
}
4038
catch
41-
case NonFatal(e) =>
42-
logger.error("Exception when polling for records", e)
43-
c.errorSafe(e)
39+
case t: Throwable =>
40+
logger.error("Exception when polling for records", t)
41+
c.errorSafe(t)
4442
}
4543

4644
c

0 commit comments

Comments
 (0)