diff --git a/tokio-tcp/Cargo.toml b/tokio-tcp/Cargo.toml
index a6c672f4205..f1c8722df3d 100644
--- a/tokio-tcp/Cargo.toml
+++ b/tokio-tcp/Cargo.toml
@@ -37,4 +37,4 @@ futures-core-preview = { version = "0.3.0-alpha.16", optional = true }
 [dev-dependencies]
 #env_logger = { version = "0.5", default-features = false }
 #net2 = "*"
-#tokio = { version = "0.2.0", path = "../tokio" }
+tokio = { version = "0.2.0", path = "../tokio" }
diff --git a/tokio-tcp/src/lib.rs b/tokio-tcp/src/lib.rs
index d88b50a4526..23fb1f01cd8 100644
--- a/tokio-tcp/src/lib.rs
+++ b/tokio-tcp/src/lib.rs
@@ -34,6 +34,7 @@ macro_rules! ready {
 #[cfg(feature = "incoming")]
 mod incoming;
 mod listener;
+pub mod split;
 mod stream;
 
 pub use self::listener::TcpListener;
diff --git a/tokio-tcp/src/split.rs b/tokio-tcp/src/split.rs
new file mode 100644
index 00000000000..ecd9640b774
--- /dev/null
+++ b/tokio-tcp/src/split.rs
@@ -0,0 +1,142 @@
+//! `TcpStream` split support.
+//!
+//! A `TcpStream` can be split into a `TcpStreamReadHalf` and a
+//! `TcpStreamWriteHalf` with the `TcpStream::split` method. `TcpStreamReadHalf`
+//! implements `AsyncRead` while `TcpStreamWriteHalf` implements `AsyncWrite`.
+//! The two halves can be used concurrently, even from multiple tasks.
+//!
+//! Compared to the generic split of `AsyncRead + AsyncWrite`, this specialized
+//! split gives read and write halves that are faster and smaller, because they
+//! do not use locks. They also provide access to the underlying `TcpStream`
+//! after split, implementing `AsRef<TcpStream>`. This allows you to call
+//! `TcpStream` methods that take `&self`, e.g., to get local and peer
+//! addresses, to get and set socket options, and to shut down the socket.
+
+use super::TcpStream;
+use bytes::{Buf, BufMut};
+use std::error::Error;
+use std::fmt;
+use std::io;
+use std::net::Shutdown;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::task::{Context, Poll};
+use tokio_io::{AsyncRead, AsyncWrite};
+
+/// Read half of a `TcpStream`.
+#[derive(Debug)]
+pub struct TcpStreamReadHalf(Arc<TcpStream>);
+
+/// Write half of a `TcpStream`.
+#[derive(Debug)]
+pub struct TcpStreamWriteHalf(Arc<TcpStream>);
+
+pub(crate) fn split(stream: TcpStream) -> (TcpStreamReadHalf, TcpStreamWriteHalf) {
+    let shared = Arc::new(stream);
+    (
+        TcpStreamReadHalf(shared.clone()),
+        TcpStreamWriteHalf(shared),
+    )
+}
+
+/// Error indicating two halves were not from the same stream, and thus could
+/// not be `reunite`d.
+#[derive(Debug)]
+pub struct ReuniteError(pub TcpStreamReadHalf, pub TcpStreamWriteHalf);
+
+impl fmt::Display for ReuniteError {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(
+            f,
+            "tried to reunite halves that are not from the same stream"
+        )
+    }
+}
+
+impl Error for ReuniteError {}
+
+impl TcpStreamReadHalf {
+    /// Attempts to put the two "halves" of a `TcpStream` back together and
+    /// recover the original stream. Succeeds only if the two "halves"
+    /// originated from the same call to `TcpStream::split`.
+    pub fn reunite(self, other: TcpStreamWriteHalf) -> Result<TcpStream, ReuniteError> {
+        if Arc::ptr_eq(&self.0, &other.0) {
+            drop(other); // leaves `self.0` as the only reference, so `try_unwrap` cannot fail
+            Ok(Arc::try_unwrap(self.0).unwrap())
+        } else {
+            Err(ReuniteError(self, other))
+        }
+    }
+}
+
+impl TcpStreamWriteHalf {
+    /// Attempts to put the two "halves" of a `TcpStream` back together and
+    /// recover the original stream. Succeeds only if the two "halves"
+    /// originated from the same call to `TcpStream::split`.
+    pub fn reunite(self, other: TcpStreamReadHalf) -> Result<TcpStream, ReuniteError> {
+        other.reunite(self)
+    }
+}
+
+impl AsRef<TcpStream> for TcpStreamReadHalf {
+    fn as_ref(&self) -> &TcpStream {
+        &self.0
+    }
+}
+
+impl AsRef<TcpStream> for TcpStreamWriteHalf {
+    fn as_ref(&self) -> &TcpStream {
+        &self.0
+    }
+}
+
+impl AsyncRead for TcpStreamReadHalf {
+    unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool {
+        false
+    }
+
+    fn poll_read(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &mut [u8],
+    ) -> Poll<io::Result<usize>> {
+        self.0.poll_read_priv(cx, buf)
+    }
+
+    fn poll_read_buf<B: BufMut>(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &mut B,
+    ) -> Poll<io::Result<usize>> {
+        self.0.poll_read_buf_priv(cx, buf)
+    }
+}
+
+impl AsyncWrite for TcpStreamWriteHalf {
+    fn poll_write(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &[u8],
+    ) -> Poll<io::Result<usize>> {
+        self.0.poll_write_priv(cx, buf)
+    }
+
+    #[inline]
+    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
+        // tcp flush is a no-op
+        Poll::Ready(Ok(()))
+    }
+
+    // `poll_shutdown` on a write half shuts down the stream in the "write" direction.
+    fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
+        self.0.shutdown(Shutdown::Write).into()
+    }
+
+    fn poll_write_buf<B: Buf>(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &mut B,
+    ) -> Poll<io::Result<usize>> {
+        self.0.poll_write_buf_priv(cx, buf)
+    }
+}
diff --git a/tokio-tcp/src/stream.rs b/tokio-tcp/src/stream.rs
index ada486acdc0..f2253a5ef54 100644
--- a/tokio-tcp/src/stream.rs
+++ b/tokio-tcp/src/stream.rs
@@ -1,10 +1,11 @@
+use crate::split::{split, TcpStreamReadHalf, TcpStreamWriteHalf};
 use bytes::{Buf, BufMut};
 use iovec::IoVec;
 use mio;
 use std::convert::TryFrom;
 use std::fmt;
 use std::future::Future;
-use std::io;
+use std::io::{self, Read, Write};
 use std::mem;
 use std::net::{self, Shutdown, SocketAddr};
 use std::pin::Pin;
@@ -712,37 +713,45 @@ impl TcpStream {
         let msg = "`TcpStream::try_clone()` is deprecated because it doesn't work as intended";
         Err(io::Error::new(io::ErrorKind::Other, msg))
     }
-}
-
-impl TryFrom<TcpStream> for mio::net::TcpStream {
-    type Error = io::Error;
 
-    /// Consumes value, returning the mio I/O object.
+    /// Split a `TcpStream` into a read half and a write half, which can be used
+    /// to read and write the stream concurrently.
     ///
-    /// See [`tokio_reactor::PollEvented::into_inner`] for more details about
-    /// resource deregistration that happens during the call.
-    fn try_from(value: TcpStream) -> Result<Self, Self::Error> {
-        value.io.into_inner()
+    /// See the module level documentation of [`split`](super::split) for more
+    /// details.
+    pub fn split(self) -> (TcpStreamReadHalf, TcpStreamWriteHalf) {
+        split(self)
     }
-}
-
-// ===== impl Read / Write =====
-
-impl AsyncRead for TcpStream {
-    unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool {
-        false
-    }
-
-    fn poll_read(
-        mut self: Pin<&mut Self>,
+
+    // == Poll IO functions that take `&self` ==
+    //
+    // They are not public because (taken from the doc of `PollEvented`):
+    //
+    // While `PollEvented` is `Sync` (if the underlying I/O type is `Sync`), the
+    // caller must ensure that there are at most two tasks that use a
+    // `PollEvented` instance concurrently. One for reading and one for writing.
+    // While violating this requirement is "safe" from a Rust memory model point
+    // of view, it will result in unexpected behavior in the form of lost
+    // notifications and tasks hanging.
+
+    pub(crate) fn poll_read_priv(
+        &self,
         cx: &mut Context<'_>,
         buf: &mut [u8],
     ) -> Poll<io::Result<usize>> {
-        Pin::new(&mut self.io).poll_read(cx, buf)
+        ready!(self.io.poll_read_ready(cx, mio::Ready::readable()))?;
+
+        match self.io.get_ref().read(buf) {
+            Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
+                self.io.clear_read_ready(cx, mio::Ready::readable())?;
+                Poll::Pending
+            }
+            x => Poll::Ready(x),
+        }
     }
 
-    fn poll_read_buf<B: BufMut>(
-        self: Pin<&mut Self>,
+    pub(crate) fn poll_read_buf_priv<B: BufMut>(
+        &self,
         cx: &mut Context<'_>,
         buf: &mut B,
     ) -> Poll<io::Result<usize>> {
@@ -804,29 +813,25 @@ impl AsyncRead for TcpStream {
             Err(e) => Poll::Ready(Err(e)),
         }
     }
-}
 
-impl AsyncWrite for TcpStream {
-    fn poll_write(
-        mut self: Pin<&mut Self>,
+    pub(crate) fn poll_write_priv(
+        &self,
         cx: &mut Context<'_>,
         buf: &[u8],
     ) -> Poll<io::Result<usize>> {
-        Pin::new(&mut self.io).poll_write(cx, buf)
-    }
-
-    #[inline]
-    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
-        // tcp flush is a no-op
-        Poll::Ready(Ok(()))
-    }
+        ready!(self.io.poll_write_ready(cx))?;
 
-    fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
-        Poll::Ready(Ok(()))
+        match self.io.get_ref().write(buf) {
+            Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
+                self.io.clear_write_ready(cx)?;
+                Poll::Pending
+            }
+            x => Poll::Ready(x),
+        }
     }
 
-    fn poll_write_buf<B: Buf>(
-        self: Pin<&mut Self>,
+    pub(crate) fn poll_write_buf_priv<B: Buf>(
+        &self,
         cx: &mut Context<'_>,
         buf: &mut B,
     ) -> Poll<io::Result<usize>> {
@@ -856,6 +861,70 @@
     }
 }
 
+impl TryFrom<TcpStream> for mio::net::TcpStream {
+    type Error = io::Error;
+
+    /// Consumes value, returning the mio I/O object.
+    ///
+    /// See [`tokio_reactor::PollEvented::into_inner`] for more details about
+    /// resource deregistration that happens during the call.
+    fn try_from(value: TcpStream) -> Result<Self, Self::Error> {
+        value.io.into_inner()
+    }
+}
+
+// ===== impl Read / Write =====
+
+impl AsyncRead for TcpStream {
+    unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool {
+        false
+    }
+
+    fn poll_read(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &mut [u8],
+    ) -> Poll<io::Result<usize>> {
+        self.poll_read_priv(cx, buf)
+    }
+
+    fn poll_read_buf<B: BufMut>(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &mut B,
+    ) -> Poll<io::Result<usize>> {
+        self.poll_read_buf_priv(cx, buf)
+    }
+}
+
+impl AsyncWrite for TcpStream {
+    fn poll_write(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &[u8],
+    ) -> Poll<io::Result<usize>> {
+        self.poll_write_priv(cx, buf)
+    }
+
+    #[inline]
+    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
+        // tcp flush is a no-op
+        Poll::Ready(Ok(()))
+    }
+
+    fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
+        Poll::Ready(Ok(()))
+    }
+
+    fn poll_write_buf<B: Buf>(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &mut B,
+    ) -> Poll<io::Result<usize>> {
+        self.poll_write_buf_priv(cx, buf)
+    }
+}
+
 impl fmt::Debug for TcpStream {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         self.io.get_ref().fmt(f)
diff --git a/tokio-tcp/tests/split.rs b/tokio-tcp/tests/split.rs
new file mode 100644
index 00000000000..8b9701ae4dd
--- /dev/null
+++ b/tokio-tcp/tests/split.rs
@@ -0,0 +1,25 @@
+#![feature(async_await)]
+
+use tokio_tcp::{TcpListener, TcpStream};
+
+#[tokio::test]
+async fn split_reunite() -> std::io::Result<()> {
+    let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap())?;
+    let stream = TcpStream::connect(&listener.local_addr()?).await?;
+
+    let (r, w) = stream.split();
+    assert!(r.reunite(w).is_ok());
+    Ok(())
+}
+
+#[tokio::test]
+async fn split_reunite_error() -> std::io::Result<()> {
+    let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap())?;
+    let stream = TcpStream::connect(&listener.local_addr()?).await?;
+    let stream1 = TcpStream::connect(&listener.local_addr()?).await?;
+
+    let (r, _) = stream.split();
+    let (_, w) = stream1.split();
+    assert!(r.reunite(w).is_err());
+    Ok(())
+}
diff --git a/tokio/examples/proxy.rs b/tokio/examples/proxy.rs
new file mode 100644
index 00000000000..7893818b1f9
--- /dev/null
+++ b/tokio/examples/proxy.rs
@@ -0,0 +1,99 @@
+//! A proxy that forwards data to another server and forwards that server's
+//! responses back to clients.
+//!
+//! Because the Tokio runtime uses a thread pool, each TCP connection is
+//! processed concurrently with all other TCP connections across multiple
+//! threads.
+//!
+//! You can showcase this by running this in one terminal:
+//!
+//!     cargo run --example proxy
+//!
+//! Then this in another terminal:
+//!
+//!     cargo run --example echo
+//!
+//! And finally this in another terminal:
+//!
+//!     cargo run --example connect 127.0.0.1:8081
+//!
+//! This final terminal will connect to our proxy, which will in turn connect to
+//! the echo server, and you'll be able to see data flowing between them.
+
+#![feature(async_await)]
+
+use futures::future::try_join;
+use futures::prelude::StreamExt;
+use std::env;
+use std::net::SocketAddr;
+use tokio;
+use tokio::io::AsyncReadExt;
+use tokio::net::tcp::split::{TcpStreamReadHalf, TcpStreamWriteHalf};
+use tokio::net::tcp::{TcpListener, TcpStream};
+
+#[tokio::main]
+async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    let listen_addr = env::args().nth(1).unwrap_or("127.0.0.1:8081".to_string());
+    let listen_addr = listen_addr.parse::<SocketAddr>()?;
+
+    let server_addr = env::args().nth(2).unwrap_or("127.0.0.1:8080".to_string());
+    let server_addr = server_addr.parse::<SocketAddr>()?;
+
+    // Create a TCP listener which will listen for incoming connections.
+    let socket = TcpListener::bind(&listen_addr)?;
+    println!("Listening on: {}", listen_addr);
+    println!("Proxying to: {}", server_addr);
+
+    let mut incoming = socket.incoming();
+    loop {
+        let stream = incoming.next().await.unwrap()?;
+        tokio::spawn(async move {
+            match proxy_client(stream, server_addr).await {
+                Err(e) => {
+                    eprintln!("Error: {}", e);
+                }
+                _ => (),
+            }
+        });
+    }
+}
+
+async fn proxy_client(
+    client_stream: TcpStream,
+    server_addr: SocketAddr,
+) -> Result<(), std::io::Error> {
+    let server_stream = TcpStream::connect(&server_addr).await?;
+
+    // Create separate read/write handles for the TCP clients that we're
+    // proxying data between.
+    //
+    // Note that while you can use `AsyncRead::split` for this operation,
+    // `TcpStream::split` gives you handles that are faster, smaller and allow
+    // proper shutdown operations.
+    let (client_r, client_w) = client_stream.split();
+    let (server_r, server_w) = server_stream.split();
+
+    let client_to_server = copy_shutdown(client_r, server_w);
+    let server_to_client = copy_shutdown(server_r, client_w);
+
+    // Run the two futures in parallel.
+    let (l1, l2) = try_join(client_to_server, server_to_client).await?;
+    println!("client wrote {} bytes and received {} bytes", l1, l2);
+    Ok(())
+}
+
+// Copy data from a read half to a write half. After the copy is done we
+// indicate to the remote side that we've finished by shutting down the
+// connection.
+async fn copy_shutdown(
+    mut r: TcpStreamReadHalf,
+    mut w: TcpStreamWriteHalf,
+) -> Result<u64, std::io::Error> {
+    let l = r.copy(&mut w).await?;
+
+    // Use this instead after `shutdown` is implemented in `AsyncWriteExt`:
+    // w.shutdown().await?;
+    w.as_ref().shutdown(std::net::Shutdown::Write)?;
+
+    Ok(l)
+}
diff --git a/tokio/src/net.rs b/tokio/src/net.rs
index 896bca3de86..3b37b5371a5 100644
--- a/tokio/src/net.rs
+++ b/tokio/src/net.rs
@@ -40,7 +40,7 @@ pub mod tcp {
     //! [`TcpListener`]: struct.TcpListener.html
     //! [incoming_method]: struct.TcpListener.html#method.incoming
     //! [`Incoming`]: struct.Incoming.html
-    pub use tokio_tcp::{TcpListener, TcpStream};
+    pub use tokio_tcp::{split, TcpListener, TcpStream};
 }
 #[cfg(feature = "tcp")]
 pub use self::tcp::{TcpListener, TcpStream};
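
For reviewers, a minimal end-to-end sketch of the API this patch adds. The handler name `echo` is hypothetical; every call in the body (`TcpStream::split`, `AsyncReadExt::copy`, the `AsRef<TcpStream>` shutdown workaround, and `reunite`) appears in the diff above, and the same nightly `async_await` feature the examples use is assumed.

```rust
#![feature(async_await)]

use tokio::io::AsyncReadExt;
use tokio::net::tcp::TcpStream;

// Echo everything the peer sends back to it, then recover the original
// stream. `copy` resolves once the peer shuts down its write direction.
async fn echo(stream: TcpStream) -> std::io::Result<()> {
    let (mut r, mut w) = stream.split();
    let n = r.copy(&mut w).await?;
    println!("echoed {} bytes", n);

    // `AsyncWriteExt::shutdown` does not exist yet (see proxy.rs above), so
    // reach the underlying `TcpStream` through `AsRef` instead.
    w.as_ref().shutdown(std::net::Shutdown::Write)?;

    // Both halves came from the same `split` call, so `reunite` succeeds;
    // dropping the returned `TcpStream` closes the socket completely.
    drop(r.reunite(w).expect("halves came from the same stream"));
    Ok(())
}
```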
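And a second sketch of the headline property: because each half owns its own `Arc<TcpStream>`, both are `Send` and can drive the same socket from different tasks without a lock. `ping` is a hypothetical helper, and the `read`/`write_all` convenience methods are assumed to exist on the preview `AsyncReadExt`/`AsyncWriteExt` traits; the diff itself only exercises `copy`.

```rust
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::tcp::TcpStream;

// Send one message from a spawned task while this task keeps reading:
// the two halves are used concurrently, as the module docs promise.
async fn ping(stream: TcpStream) -> std::io::Result<()> {
    let (mut r, mut w) = stream.split();

    tokio::spawn(async move {
        if w.write_all(b"ping\n").await.is_ok() {
            // Closing the write direction signals end-of-stream to the peer.
            let _ = w.as_ref().shutdown(std::net::Shutdown::Write);
        }
    });

    let mut buf = [0u8; 64];
    loop {
        match r.read(&mut buf).await? {
            0 => return Ok(()), // peer finished writing
            n => println!("received {} bytes", n),
        }
    }
}
```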