Skip to content

Commit

Permalink
KRPC-101 Check if the entire stream is not already closed. (#158)
Browse files Browse the repository at this point in the history
* KRPC-101 Check if the entire stream is not already closed.
In such case, the incomingChannels get cleared and closedStreams don't contain the streamId which leads to deadlock

* Fix import and yarn.lock

---------

Co-authored-by: Alexander Sysoev <[email protected]>
  • Loading branch information
pikinier20 and Mr3zee authored Aug 9, 2024
1 parent 582aa99 commit 45de74d
Showing 1 changed file with 18 additions and 12 deletions.
30 changes: 18 additions & 12 deletions core/src/commonMain/kotlin/kotlinx/rpc/internal/RPCStreamContext.kt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import kotlinx.coroutines.flow.FlowCollector
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.selects.select
import kotlinx.rpc.RPCConfig
import kotlinx.rpc.StreamScope
import kotlinx.rpc.internal.map.ConcurrentHashMap
Expand Down Expand Up @@ -55,9 +56,10 @@ public class RPCStreamContext(
private companion object {
private const val STREAM_ID_PREFIX = "stream:"
}
private val closed = CompletableDeferred<Unit>()

// thread-safe set
private val closedStreams = ConcurrentHashMap<String, Unit>()
private val closedStreams = ConcurrentHashMap<String, CompletableDeferred<Unit>>()

@InternalRPCApi
public inline fun launchIf(
Expand Down Expand Up @@ -168,7 +170,7 @@ public class RPCStreamContext(
fun onClose() {
incoming.cancel()

closedStreams.put(streamId, Unit)
closedStreams[streamId] = Unit
incomingChannels.remove(streamId)?.complete(null)
incomingStreams.remove(streamId)
}
Expand Down Expand Up @@ -242,27 +244,31 @@ public class RPCStreamContext(
}

public suspend fun send(message: RPCCallMessage.StreamMessage, serialFormat: SerialFormat) {
val info = incomingStreams.getDeferred(message.streamId).await()
val info: RPCStreamCall? = select {
incomingStreams.getDeferred(message.streamId).onAwait { it }
closedStreams.getDeferred(message.streamId).onAwait { null }
closed.onAwait { null }
}
if (info == null) return
val result = decodeMessageData(serialFormat, info.elementSerializer, message)
incomingChannelOf(message.streamId)?.send(result)
val channel = incomingChannelOf(message.streamId)
channel?.send(result)
}

private suspend fun incomingChannelOf(streamId: String): Channel<Any?>? {
if (closedStreams.containsKey(streamId)) {
return null
return select {
incomingChannels.getDeferred(streamId).onAwait { it }
closedStreams.getDeferred(streamId).onAwait { null }
closed.onAwait { null }
}

return incomingChannels.getDeferred(streamId).await()
}

private var closed = false

private fun close(cause: Throwable?) {
if (closed) {
if (closed.isCompleted) {
return
}

closed = true
closed.complete(Unit)

if (incomingChannelsInitialized) {
for (channel in incomingChannels.values) {
Expand Down

0 comments on commit 45de74d

Please sign in to comment.