Skip to content

Commit 426bd08

Browse files
authored
Merge pull request #230 from getkyo/fibers-race-improvement
improve Fibers.race to wait for the first successful result
2 parents f8f1b7c + 6e37a03 commit 426bd08

File tree

2 files changed

+41
-9
lines changed

2 files changed

+41
-9
lines changed

kyo-core/shared/src/main/scala/kyo/fibers.scala

+18-8
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import scala.collection.immutable.ArraySeq
1111
import scala.concurrent.ExecutionContext
1212
import scala.concurrent.Future
1313
import scala.concurrent.duration.Duration
14+
import scala.runtime.AbstractFunction1
1415
import scala.util.*
1516
import scala.util.control.NonFatal
1617
import scala.util.control.NoStackTrace
@@ -171,24 +172,33 @@ object Fibers extends Joins[Fibers]:
171172
}
172173

173174
def race[T](l: Seq[T < Fibers])(using f: Flat[T < Fibers]): T < Fibers =
174-
l.size match
175-
case 0 => IOs.fail("Can't race an empty list.")
176-
case 1 => l(0)
177-
case _ =>
178-
Fibers.get(raceFiber[T](l))
175+
Fibers.get(raceFiber[T](l))
179176

180177
def raceFiber[T](l: Seq[T < Fibers])(using f: Flat[T < Fibers]): Fiber[T] < IOs =
181178
l.size match
182179
case 0 => IOs.fail("Can't race an empty list.")
183180
case 1 => Fibers.run(l(0))
184-
case _ =>
181+
case size =>
185182
Locals.save.map { st =>
186183
IOs {
187-
val p = new IOPromise[T]
184+
class State extends AbstractFunction1[T < IOs, Unit]:
185+
val p = new IOPromise[T]
186+
val pending = new AtomicInteger(size)
187+
def apply(v: T < IOs): Unit =
188+
val last = pending.decrementAndGet() == 0
189+
try discard(p.complete(IOs.run(v)))
190+
catch
191+
case ex if (NonFatal(ex)) =>
192+
if last then discard(p.complete(IOs.fail(ex)))
193+
end try
194+
end apply
195+
end State
196+
val state = new State
197+
import state.*
188198
foreach(l) { (i, io) =>
189199
val f = IOTask(IOs(io), st)
190200
p.interrupts(f)
191-
f.onComplete(v => discard(p.complete(v)))
201+
f.onComplete(state)
192202
}
193203
Promise(p)
194204
}

kyo-core/shared/src/test/scala/kyoTest/fibersTest.scala

+23-1
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ class fibersTest extends KyoTest:
159159
assert(r == 1)
160160
}
161161
}
162-
"n" in runJVM {
162+
"multiple" in runJVM {
163163
val ac = new JAtomicInteger(0)
164164
val bc = new JAtomicInteger(0)
165165
def loop(i: Int, s: String): String < IOs =
@@ -177,6 +177,28 @@ class fibersTest extends KyoTest:
177177
assert(bc.get() <= Int.MaxValue)
178178
}
179179
}
180+
"waits for the first success" in runJVM {
181+
val ex = new Exception
182+
Fibers.race(
183+
Fibers.sleep(1.millis).andThen(42),
184+
IOs.fail[Int](ex)
185+
).map { r =>
186+
assert(r == 42)
187+
}
188+
}
189+
"returns the last failure if all fibers fail" in runJVM {
190+
val ex1 = new Exception
191+
val ex2 = new Exception
192+
IOs.attempt(
193+
Fibers.race(
194+
Fibers.sleep(1.millis).andThen(IOs.fail[Int](ex1)),
195+
IOs.fail[Int](ex2)
196+
)
197+
).map {
198+
r =>
199+
assert(r == Failure(ex1))
200+
}
201+
}
180202
}
181203

182204
"raceFiber" - {

0 commit comments

Comments
 (0)