@@ -11,6 +11,7 @@ import scala.collection.immutable.ArraySeq
11
11
import scala .concurrent .ExecutionContext
12
12
import scala .concurrent .Future
13
13
import scala .concurrent .duration .Duration
14
+ import scala .runtime .AbstractFunction1
14
15
import scala .util .*
15
16
import scala .util .control .NonFatal
16
17
import scala .util .control .NoStackTrace
@@ -171,24 +172,33 @@ object Fibers extends Joins[Fibers]:
171
172
}
172
173
173
174
def race [T ](l : Seq [T < Fibers ])(using f : Flat [T < Fibers ]): T < Fibers =
174
- l.size match
175
- case 0 => IOs .fail(" Can't race an empty list." )
176
- case 1 => l(0 )
177
- case _ =>
178
- Fibers .get(raceFiber[T ](l))
175
+ Fibers .get(raceFiber[T ](l))
179
176
180
177
def raceFiber [T ](l : Seq [T < Fibers ])(using f : Flat [T < Fibers ]): Fiber [T ] < IOs =
181
178
l.size match
182
179
case 0 => IOs .fail(" Can't race an empty list." )
183
180
case 1 => Fibers .run(l(0 ))
184
- case _ =>
181
+ case size =>
185
182
Locals .save.map { st =>
186
183
IOs {
187
- val p = new IOPromise [T ]
184
+ class State extends AbstractFunction1 [T < IOs , Unit ]:
185
+ val p = new IOPromise [T ]
186
+ val pending = new AtomicInteger (size)
187
+ def apply (v : T < IOs ): Unit =
188
+ val last = pending.decrementAndGet() == 0
189
+ try discard(p.complete(IOs .run(v)))
190
+ catch
191
+ case ex if (NonFatal (ex)) =>
192
+ if last then discard(p.complete(IOs .fail(ex)))
193
+ end try
194
+ end apply
195
+ end State
196
+ val state = new State
197
+ import state .*
188
198
foreach(l) { (i, io) =>
189
199
val f = IOTask (IOs (io), st)
190
200
p.interrupts(f)
191
- f.onComplete(v => discard(p.complete(v)) )
201
+ f.onComplete(state )
192
202
}
193
203
Promise (p)
194
204
}
0 commit comments