Skip to content

Commit

Permalink
KTOR-829 Support 100 Continue on client side (#3469)
Browse files Browse the repository at this point in the history
  • Loading branch information
marychatte authored Apr 18, 2023
1 parent 977cf8b commit 7b7ce23
Show file tree
Hide file tree
Showing 3 changed files with 274 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@ import io.ktor.client.plugins.*
import io.ktor.client.request.*
import io.ktor.client.utils.*
import io.ktor.http.*
import io.ktor.http.content.*
import io.ktor.network.sockets.*
import io.ktor.network.tls.*
import io.ktor.util.*
import io.ktor.util.date.*
import io.ktor.utils.io.*
import io.ktor.utils.io.core.*
import kotlinx.atomicfu.*
import kotlinx.coroutines.*
Expand Down Expand Up @@ -119,13 +121,62 @@ internal class Endpoint(
setupTimeout(callContext, request, timeout)

val requestTime = GMTDate()
writeRequest(request, output, callContext, proxy != null)
return readResponse(requestTime, request, input, originOutput, callContext)
val overProxy = proxy != null

return if (expectContinue(request.headers[HttpHeaders.Expect], request.body)) {
processExpectContinue(request, input, output, originOutput, callContext, requestTime, overProxy)
} else {
writeRequest(request, output, callContext, overProxy)
readResponse(requestTime, request, input, originOutput, callContext)
}
} catch (cause: Throwable) {
throw cause.mapToKtor(request)
}
}

@OptIn(InternalCoroutinesApi::class)
private suspend fun processExpectContinue(
request: HttpRequestData,
input: ByteReadChannel,
output: ByteWriteChannel,
originOutput: ByteWriteChannel,
callContext: CoroutineContext,
requestTime: GMTDate,
overProxy: Boolean,
) = withContext(callContext) {
writeHeaders(request, output, overProxy)

val responseReady = withTimeoutOrNull(CONTINUE_RESPONSE_TIMEOUT_MILLIS) {
input.awaitContent()
}

if (responseReady != null) {
val response = readResponse(requestTime, request, input, originOutput, callContext)
when (response.statusCode) {
HttpStatusCode.ExpectationFailed -> {
val newRequest = HttpRequestBuilder().apply {
takeFrom(request)
headers.remove(HttpHeaders.Expect)
}.build()
writeRequest(newRequest, output, callContext, overProxy)
}

HttpStatusCode.Continue -> {
writeBody(request, output, callContext)
}

else -> {
output.close()
return@withContext response
}
}
} else {
writeBody(request, output, callContext)
}

return@withContext readResponse(requestTime, request, input, originOutput, callContext)
}

private suspend fun createPipeline(request: HttpRequestData) {
val connection = connect(request)

Expand All @@ -141,6 +192,7 @@ internal class Endpoint(
pipeline.pipelineContext.invokeOnCompletion { releaseConnection() }
}

@Suppress("UNUSED_EXPRESSION")
private suspend fun connect(requestData: HttpRequestData): Connection {
val connectAttempts = config.endpoint.connectAttempts
val (connectTimeout, socketTimeout) = retrieveTimeouts(requestData)
Expand Down Expand Up @@ -241,6 +293,10 @@ internal class Endpoint(
override fun close() {
timeout.cancel()
}

companion object {
const val CONTINUE_RESPONSE_TIMEOUT_MILLIS = 1000L
}
}

@OptIn(DelicateCoroutinesApi::class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,17 @@ internal suspend fun writeRequest(
overProxy: Boolean,
closeChannel: Boolean = true
) = withContext(callContext) {
writeHeaders(request, output, overProxy, closeChannel)
writeBody(request, output, callContext)
}

@OptIn(InternalAPI::class)
internal suspend fun writeHeaders(
request: HttpRequestData,
output: ByteWriteChannel,
overProxy: Boolean,
closeChannel: Boolean = true
) {
val builder = RequestResponseBuilder()

val method = request.method
Expand All @@ -40,7 +51,8 @@ internal suspend fun writeRequest(
val contentLength = headers[HttpHeaders.ContentLength] ?: body.contentLength?.toString()
val contentEncoding = headers[HttpHeaders.TransferEncoding]
val responseEncoding = body.headers[HttpHeaders.TransferEncoding]
val chunked = contentLength == null || responseEncoding == "chunked" || contentEncoding == "chunked"
val chunked = isChunked(contentLength, responseEncoding, contentEncoding)
val expected = headers[HttpHeaders.Expect]

try {
val normalizedUrl = if (url.pathSegments.isEmpty()) URLBuilder(url).apply { encodedPath = "/" }.build() else url
Expand All @@ -64,7 +76,7 @@ internal suspend fun writeRequest(
}

mergeHeaders(headers, body) { key, value ->
if (key == HttpHeaders.ContentLength) return@mergeHeaders
if (key == HttpHeaders.ContentLength || key == HttpHeaders.Expect) return@mergeHeaders

builder.headerLine(key, value)
}
Expand All @@ -73,6 +85,10 @@ internal suspend fun writeRequest(
builder.headerLine(HttpHeaders.TransferEncoding, "chunked")
}

if (expectContinue(expected, body)) {
builder.headerLine(HttpHeaders.Expect, expected!!)
}

builder.emptyLine()
output.writePacket(builder.build())
output.flush()
Expand All @@ -84,19 +100,31 @@ internal suspend fun writeRequest(
} finally {
builder.release()
}
}

if (body is OutgoingContent.NoContent) {
internal suspend fun writeBody(
request: HttpRequestData,
output: ByteWriteChannel,
callContext: CoroutineContext,
closeChannel: Boolean = true
) {
if (request.body is OutgoingContent.NoContent) {
if (closeChannel) output.close()
return@withContext
return
}

val contentLength = request.headers[HttpHeaders.ContentLength] ?: request.body.contentLength?.toString()
val contentEncoding = request.headers[HttpHeaders.TransferEncoding]
val responseEncoding = request.body.headers[HttpHeaders.TransferEncoding]
val chunked = isChunked(contentLength, responseEncoding, contentEncoding)

val chunkedJob: EncoderJob? = if (chunked) encodeChunked(output, callContext) else null
val channel = chunkedJob?.channel ?: output

val scope = CoroutineScope(callContext + CoroutineName("Request body writer"))
scope.launch {
try {
when (body) {
when (val body = request.body) {
is OutgoingContent.NoContent -> return@launch
is OutgoingContent.ByteArrayContent -> channel.writeFully(body.bytes())
is OutgoingContent.ReadChannelContent -> body.readFrom().copyAndClose(channel)
Expand Down Expand Up @@ -152,6 +180,7 @@ internal suspend fun readResponse(
status.isInformational() -> {
ByteReadChannel.Empty
}

else -> {
val coroutineScope = CoroutineScope(callContext + CoroutineName("Response"))
val httpBodyParser = coroutineScope.writer(autoFlush = true) {
Expand Down Expand Up @@ -252,3 +281,12 @@ internal fun ByteWriteChannel.handleHalfClosed(
coroutineContext: CoroutineContext,
propagateClose: Boolean
): ByteWriteChannel = if (propagateClose) this else withoutClosePropagation(coroutineContext)

internal fun isChunked(
contentLength: String?,
responseEncoding: String?,
contentEncoding: String?
) = contentLength == null || responseEncoding == "chunked" || contentEncoding == "chunked"

internal fun expectContinue(expectHeader: String?, body: OutgoingContent) =
expectHeader != null && body !is OutgoingContent.NoContent
173 changes: 173 additions & 0 deletions ktor-client/ktor-client-cio/jvmAndNix/test/CIOEngineTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,21 @@
import io.ktor.client.*
import io.ktor.client.engine.cio.*
import io.ktor.client.plugins.websocket.*
import io.ktor.client.request.*
import io.ktor.client.tests.utils.*
import io.ktor.http.*
import io.ktor.network.selector.*
import io.ktor.network.sockets.*
import io.ktor.utils.io.*
import io.ktor.utils.io.core.*
import io.ktor.websocket.*
import kotlinx.coroutines.*
import kotlin.test.*

class CIOEngineTest {

private val selectorManager = SelectorManager()

@Test
fun testRequestTimeoutIgnoredWithWebSocket(): Unit = runBlocking {
val client = HttpClient(CIO) {
Expand All @@ -34,4 +42,169 @@ class CIOEngineTest {

assertTrue(received)
}

@Test
fun testExpectHeader(): Unit = runBlocking {
val body = "Hello World"

withServerSocket { socket ->
val client = HttpClient(CIO)
launch {
sendExpectRequest(socket, client, body).apply {
assertEquals(HttpStatusCode.OK, status)
}
}

socket.accept().use {
val readChannel = it.openReadChannel()
val writeChannel = it.openWriteChannel()

val headers = readAvailableLines(readChannel)
assertTrue(headers.contains(EXPECT_HEADER))
assertFalse(headers.contains(body))

writeContinueResponse(writeChannel)
val actualBody = readAvailableLine(readChannel)
assertEquals(body, actualBody)
writeOkResponse(writeChannel)
}
}
}

@Test
fun testNoExpectHeaderIfNoBody(): Unit = runBlocking {
withServerSocket { socket ->
val client = HttpClient(CIO)
launch {
sendExpectRequest(socket, client).apply {
assertEquals(HttpStatusCode.OK, status)
}
}

socket.accept().use {
val readChannel = it.openReadChannel()
val writeChannel = it.openWriteChannel()

val headers = readAvailableLines(readChannel)
assertFalse(headers.contains(EXPECT_HEADER))
writeOkResponse(writeChannel)
}
}
}

@Test
fun testDontWaitForContinueResponse(): Unit = runBlocking {
val body = "Hello World\n"

withServerSocket { socket ->
val client = HttpClient(CIO) {
engine {
requestTimeout = 0
}
}
launch {
sendExpectRequest(socket, client, body).apply {
assertEquals(HttpStatusCode.OK, status)
}
}

socket.accept().use {
val readChannel = it.openReadChannel()
val writeChannel = it.openWriteChannel()

val headers = readAvailableLines(readChannel)
delay(2000)
val actualBody = readAvailableLine(readChannel)
assertTrue(headers.contains(EXPECT_HEADER))
assertEquals(body, actualBody)
writeOkResponse(writeChannel)
}
}
}

@Test
fun testRepeatRequestAfterExpectationFailed(): Unit = runBlocking {
val body = "Hello World"

withServerSocket { socket ->
val client = HttpClient(CIO)
launch {
sendExpectRequest(socket, client, body).apply {
assertEquals(HttpStatusCode.OK, status)
}
}

socket.accept().use {
val readChannel = it.openReadChannel()
val writeChannel = it.openWriteChannel()

val headers = readAvailableLines(readChannel)
assertTrue(headers.contains(EXPECT_HEADER))
writeExpectationFailedResponse(writeChannel)

delay(100) // because channel.flush() happens between writing headers and body
val newRequest = readAvailableLines(readChannel)
assertFalse(newRequest.contains(EXPECT_HEADER))
assertTrue(newRequest.contains(body))
writeOkResponse(writeChannel)
}
}
}

private suspend fun sendExpectRequest(socket: ServerSocket, client: HttpClient, body: String? = null) =
client.post {
val serverPort = (socket.localAddress as InetSocketAddress).port
url(host = TEST_SERVER_SOCKET_HOST, port = serverPort, path = "/")
header(HttpHeaders.Expect, "100-continue")
if (body != null) setBody(body)
}

private suspend fun readAvailableLine(channel: ByteReadChannel): String {
val buffer = ByteArray(1024)
val length = channel.readAvailable(buffer)
return String(buffer, length = length)
}

private suspend fun readAvailableLines(channel: ByteReadChannel): List<String> {
return readAvailableLine(channel).split("\r\n")
}

private suspend fun writeContinueResponse(channel: ByteWriteChannel) {
channel.apply {
writeStringUtf8("HTTP/1.1 100 Continue\r\n")
writeStringUtf8("\r\n")
flush()
}
}

private suspend fun writeOkResponse(channel: ByteWriteChannel) {
channel.apply {
writeStringUtf8("HTTP/1.1 200 Ok\r\n")
writeStringUtf8("Content-Length: 0\r\n")
writeStringUtf8("\r\n")
flush()
}
}

private suspend fun writeExpectationFailedResponse(channel: ByteWriteChannel) {
channel.apply {
writeStringUtf8("HTTP/1.1 417 Expectation Failed\r\n")
writeStringUtf8("Content-Length: 0\r\n")
writeStringUtf8("\r\n")
flush()
}
}

private suspend fun withServerSocket(block: suspend (ServerSocket) -> Unit) {
selectorManager.use {
aSocket(it).tcp().bind(TEST_SERVER_SOCKET_HOST, 0).use { socket ->
block(socket)
}
}
}

companion object {
private const val TEST_SERVER_SOCKET_HOST = "0.0.0.0"
private const val EXPECT_HEADER = "Expect: 100-continue"
}
}

0 comments on commit 7b7ce23

Please sign in to comment.