|
| 1 | +/* |
| 2 | + * Copyright 2015-2024 the original author or authors. |
| 3 | + * |
| 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + * you may not use this file except in compliance with the License. |
| 6 | + * You may obtain a copy of the License at |
| 7 | + * |
| 8 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | + * |
| 10 | + * Unless required by applicable law or agreed to in writing, software |
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + * See the License for the specific language governing permissions and |
| 14 | + * limitations under the License. |
| 15 | + */ |
| 16 | + |
| 17 | +package io.rsocket.kotlin.connection |
| 18 | + |
| 19 | +import io.ktor.utils.io.core.* |
| 20 | +import io.rsocket.kotlin.* |
| 21 | +import io.rsocket.kotlin.frame.* |
| 22 | +import io.rsocket.kotlin.internal.* |
| 23 | +import io.rsocket.kotlin.internal.io.* |
| 24 | +import io.rsocket.kotlin.operation.* |
| 25 | +import io.rsocket.kotlin.payload.* |
| 26 | +import io.rsocket.kotlin.transport.* |
| 27 | +import kotlinx.coroutines.* |
| 28 | +import kotlinx.coroutines.flow.* |
| 29 | +import kotlin.coroutines.* |
| 30 | + |
| 31 | +// TODO: rename to just `Connection` after root `Connection` will be dropped |
| 32 | +@RSocketTransportApi |
| 33 | +internal abstract class Connection2( |
| 34 | + protected val frameCodec: FrameCodec, |
| 35 | + // requestContext |
| 36 | + final override val coroutineContext: CoroutineContext, |
| 37 | +) : RSocket, Closeable { |
| 38 | + |
| 39 | + // connection establishment part |
| 40 | + |
| 41 | + abstract suspend fun establishConnection(handler: ConnectionEstablishmentHandler): ConnectionConfig |
| 42 | + |
| 43 | + // setup completed, start handling requests |
| 44 | + abstract suspend fun handleConnection(inbound: ConnectionInbound) |
| 45 | + |
| 46 | + // connection part |
| 47 | + |
| 48 | + protected abstract suspend fun sendConnectionFrame(frame: ByteReadPacket) |
| 49 | + private suspend fun sendConnectionFrame(frame: Frame): Unit = sendConnectionFrame(frameCodec.encodeFrame(frame)) |
| 50 | + |
| 51 | + suspend fun sendError(cause: Throwable) { |
| 52 | + sendConnectionFrame(ErrorFrame(0, cause)) |
| 53 | + } |
| 54 | + |
| 55 | + private suspend fun sendMetadataPush(metadata: ByteReadPacket) { |
| 56 | + sendConnectionFrame(MetadataPushFrame(metadata)) |
| 57 | + } |
| 58 | + |
| 59 | + suspend fun sendKeepAlive(respond: Boolean, data: ByteReadPacket, lastPosition: Long) { |
| 60 | + sendConnectionFrame(KeepAliveFrame(respond, lastPosition, data)) |
| 61 | + } |
| 62 | + |
| 63 | + // operations part |
| 64 | + |
| 65 | + protected abstract fun launchRequest(requestPayload: Payload, operation: RequesterOperation): Job |
| 66 | + private suspend fun ensureActiveOrClose(closeable: Closeable) { |
| 67 | + currentCoroutineContext().ensureActive { closeable.close() } |
| 68 | + coroutineContext.ensureActive { closeable.close() } |
| 69 | + } |
| 70 | + |
| 71 | + final override suspend fun metadataPush(metadata: ByteReadPacket) { |
| 72 | + ensureActiveOrClose(metadata) |
| 73 | + sendMetadataPush(metadata) |
| 74 | + } |
| 75 | + |
| 76 | + final override suspend fun fireAndForget(payload: Payload) { |
| 77 | + ensureActiveOrClose(payload) |
| 78 | + |
| 79 | + suspendCancellableCoroutine { cont -> |
| 80 | + val requestJob = launchRequest( |
| 81 | + requestPayload = payload, |
| 82 | + operation = RequesterFireAndForgetOperation(cont) |
| 83 | + ) |
| 84 | + cont.invokeOnCancellation { cause -> |
| 85 | + requestJob.cancel("Request was cancelled", cause) |
| 86 | + } |
| 87 | + } |
| 88 | + } |
| 89 | + |
| 90 | + final override suspend fun requestResponse(payload: Payload): Payload { |
| 91 | + ensureActiveOrClose(payload) |
| 92 | + |
| 93 | + val responseDeferred = CompletableDeferred<Payload>() |
| 94 | + |
| 95 | + val requestJob = launchRequest( |
| 96 | + requestPayload = payload, |
| 97 | + operation = RequesterRequestResponseOperation(responseDeferred) |
| 98 | + ) |
| 99 | + |
| 100 | + try { |
| 101 | + responseDeferred.join() |
| 102 | + } catch (cause: Throwable) { |
| 103 | + requestJob.cancel("Request was cancelled", cause) |
| 104 | + throw cause |
| 105 | + } |
| 106 | + return responseDeferred.await() |
| 107 | + } |
| 108 | + |
| 109 | + @OptIn(ExperimentalStreamsApi::class) |
| 110 | + final override fun requestStream( |
| 111 | + payload: Payload, |
| 112 | + ): Flow<Payload> = payloadFlow { strategy, initialRequest -> |
| 113 | + ensureActiveOrClose(payload) |
| 114 | + |
| 115 | + val responsePayloads = PayloadChannel() |
| 116 | + |
| 117 | + val requestJob = launchRequest( |
| 118 | + requestPayload = payload, |
| 119 | + operation = RequesterRequestStreamOperation(initialRequest, responsePayloads) |
| 120 | + ) |
| 121 | + |
| 122 | + throw try { |
| 123 | + responsePayloads.consumeInto(this, strategy) |
| 124 | + } catch (cause: Throwable) { |
| 125 | + requestJob.cancel("Request was cancelled", cause) |
| 126 | + throw cause |
| 127 | + } ?: return@payloadFlow |
| 128 | + } |
| 129 | + |
| 130 | + @OptIn(ExperimentalStreamsApi::class) |
| 131 | + final override fun requestChannel( |
| 132 | + initPayload: Payload, |
| 133 | + payloads: Flow<Payload>, |
| 134 | + ): Flow<Payload> = payloadFlow { strategy, initialRequest -> |
| 135 | + ensureActiveOrClose(initPayload) |
| 136 | + |
| 137 | + val responsePayloads = PayloadChannel() |
| 138 | + |
| 139 | + val requestJob = launchRequest( |
| 140 | + initPayload, |
| 141 | + RequesterRequestChannelOperation(initialRequest, payloads, responsePayloads) |
| 142 | + ) |
| 143 | + |
| 144 | + throw try { |
| 145 | + responsePayloads.consumeInto(this, strategy) |
| 146 | + } catch (cause: Throwable) { |
| 147 | + requestJob.cancel("Request was cancelled", cause) |
| 148 | + throw cause |
| 149 | + } ?: return@payloadFlow |
| 150 | + } |
| 151 | +} |
0 commit comments