From 41c51f1c61ac957e439ced4302f09160c850787e Mon Sep 17 00:00:00 2001 From: Lucio Franco Date: Fri, 15 Jan 2021 11:59:26 -0500 Subject: [PATCH] feat(transport): Fix TLS accept w/ peer certs (#535) * feat(transport): Fix TLS accept w/ peer certs * fix unused var * fix feature flag imports * spawn accept task --- examples/src/tls_client_auth/server.rs | 8 +- tonic/Cargo.toml | 1 + tonic/src/transport/server/conn.rs | 30 ++-- tonic/src/transport/server/incoming.rs | 209 ++++++++++++------------- tonic/src/transport/server/mod.rs | 2 +- tonic/src/transport/service/tls.rs | 6 +- 6 files changed, 118 insertions(+), 138 deletions(-) diff --git a/examples/src/tls_client_auth/server.rs b/examples/src/tls_client_auth/server.rs index 2a719f6f4..d1d88d789 100644 --- a/examples/src/tls_client_auth/server.rs +++ b/examples/src/tls_client_auth/server.rs @@ -17,9 +17,11 @@ pub struct EchoServer; #[tonic::async_trait] impl pb::echo_server::Echo for EchoServer { async fn unary_echo(&self, request: Request) -> EchoResult { - if let Some(certs) = request.peer_certs() { - println!("Got {} peer certs!", certs.len()); - } + let certs = request + .peer_certs() + .expect("Client did not send its certs!"); + + println!("Got {} peer certs!", certs.len()); let message = request.into_inner().message; Ok(Response::new(EchoResponse { message })) diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml index ef1b65062..540c17c04 100644 --- a/tonic/Cargo.toml +++ b/tonic/Cargo.toml @@ -31,6 +31,7 @@ transport = [ "tokio", "tower", "tracing-futures", + "tokio/macros" ] tls = ["transport", "tokio-rustls"] tls-roots = ["tls", "rustls-native-certs"] diff --git a/tonic/src/transport/server/conn.rs b/tonic/src/transport/server/conn.rs index d7d458840..f5bbcfc08 100644 --- a/tonic/src/transport/server/conn.rs +++ b/tonic/src/transport/server/conn.rs @@ -1,11 +1,9 @@ -#[cfg(feature = "tls")] -use super::TlsStream; use crate::transport::Certificate; use hyper::server::conn::AddrStream; use std::net::SocketAddr; use tokio::net::TcpStream; #[cfg(feature = "tls")] -use tokio_rustls::rustls::Session; +use tokio_rustls::{rustls::Session, server::TlsStream}; /// Trait that connected IO resources implement. /// @@ -39,24 +37,20 @@ impl Connected for TcpStream { #[cfg(feature = "tls")] impl Connected for TlsStream { fn remote_addr(&self) -> Option { - if let Some((inner, _)) = self.get_ref() { - inner.remote_addr() - } else { - None - } + let (inner, _) = self.get_ref(); + + inner.remote_addr() } fn peer_certs(&self) -> Option> { - if let Some((_, session)) = self.get_ref() { - if let Some(certs) = session.get_peer_certificates() { - let certs = certs - .into_iter() - .map(|c| Certificate::from_pem(c.0)) - .collect(); - Some(certs) - } else { - None - } + let (_, session) = self.get_ref(); + + if let Some(certs) = session.get_peer_certificates() { + let certs = certs + .into_iter() + .map(|c| Certificate::from_pem(c.0)) + .collect(); + Some(certs) } else { None } diff --git a/tonic/src/transport/server/incoming.rs b/tonic/src/transport/server/incoming.rs index 2e1112863..8e8103e90 100644 --- a/tonic/src/transport/server/incoming.rs +++ b/tonic/src/transport/server/incoming.rs @@ -14,10 +14,10 @@ use std::{ }; use tokio::io::{AsyncRead, AsyncWrite}; -#[cfg_attr(not(feature = "tls"), allow(unused_variables))] +#[cfg(not(feature = "tls"))] pub(crate) fn tcp_incoming( incoming: impl Stream>, - server: Server, + _server: Server, ) -> impl Stream> where IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static, @@ -26,145 +26,130 @@ where async_stream::try_stream! { futures_util::pin_mut!(incoming); + while let Some(stream) = incoming.try_next().await? { - #[cfg(feature = "tls")] - { - if let Some(tls) = &server.tls { - let io = tls.accept(stream); - yield ServerIo::new(io); - continue; - } - } yield ServerIo::new(stream); } } } -pub(crate) struct TcpIncoming { - inner: AddrIncoming, -} +#[cfg(feature = "tls")] +pub(crate) fn tcp_incoming( + incoming: impl Stream>, + server: Server, +) -> impl Stream> +where + IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static, + IE: Into, +{ + async_stream::try_stream! { + futures_util::pin_mut!(incoming); -impl TcpIncoming { - pub(crate) fn new( - addr: SocketAddr, - nodelay: bool, - keepalive: Option, - ) -> Result { - let mut inner = AddrIncoming::bind(&addr)?; - inner.set_nodelay(nodelay); - inner.set_keepalive(keepalive); - Ok(TcpIncoming { inner }) - } -} + #[cfg(feature = "tls")] + let mut tasks = futures_util::stream::futures_unordered::FuturesUnordered::new(); -impl Stream for TcpIncoming { - type Item = Result; + loop { + match select(&mut incoming, &mut tasks).await { + SelectOutput::Incoming(stream) => { + if let Some(tls) = &server.tls { + let tls = tls.clone(); - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.inner).poll_accept(cx) - } -} + let accept = tokio::spawn(async move { + let io = tls.accept(stream).await?; + Ok(ServerIo::new(io)) + }); -// tokio_rustls::server::TlsStream doesn't expose constructor methods, -// so we have to TlsAcceptor::accept and handshake to have access to it -// TlsStream implements AsyncRead/AsyncWrite handshaking tokio_rustls::Accept first -#[cfg(feature = "tls")] -pub(crate) struct TlsStream { - state: State, -} + tasks.push(accept); + } else { + yield ServerIo::new(stream); + } + } -#[cfg(feature = "tls")] -enum State { - Handshaking(tokio_rustls::Accept), - Streaming(tokio_rustls::server::TlsStream), -} + SelectOutput::Io(io) => { + yield io; + } -#[cfg(feature = "tls")] -impl TlsStream { - pub(crate) fn new(accept: tokio_rustls::Accept) -> Self { - TlsStream { - state: State::Handshaking(accept), - } - } + SelectOutput::Err(e) => { + tracing::error!(message = "Accept loop error.", error = %e); + } - pub(crate) fn get_ref(&self) -> Option<(&IO, &tokio_rustls::rustls::ServerSession)> { - if let State::Streaming(tls) = &self.state { - Some(tls.get_ref()) - } else { - None + SelectOutput::Done => { + break; + } + } } } } #[cfg(feature = "tls")] -impl AsyncRead for TlsStream +async fn select( + incoming: &mut (impl Stream> + Unpin), + tasks: &mut futures_util::stream::futures_unordered::FuturesUnordered< + tokio::task::JoinHandle>, + >, +) -> SelectOutput where - IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static, + IE: Into, { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> Poll> { - use std::future::Future; - - let pin = self.get_mut(); - match pin.state { - State::Handshaking(ref mut accept) => { - match futures_core::ready!(Pin::new(accept).poll(cx)) { - Ok(mut stream) => { - let result = Pin::new(&mut stream).poll_read(cx, buf); - pin.state = State::Streaming(stream); - result - } - Err(err) => Poll::Ready(Err(err)), - } + use futures_util::StreamExt; + + if tasks.is_empty() { + return match incoming.try_next().await { + Ok(Some(stream)) => SelectOutput::Incoming(stream), + Ok(None) => SelectOutput::Done, + Err(e) => SelectOutput::Err(e.into()), + }; + } + + tokio::select! { + stream = incoming.try_next() => { + match stream { + Ok(Some(stream)) => SelectOutput::Incoming(stream), + Ok(None) => SelectOutput::Done, + Err(e) => SelectOutput::Err(e.into()), } - State::Streaming(ref mut stream) => Pin::new(stream).poll_read(cx, buf), } - } -} -#[cfg(feature = "tls")] -impl AsyncWrite for TlsStream -where - IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static, -{ - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - use std::future::Future; - - let pin = self.get_mut(); - match pin.state { - State::Handshaking(ref mut accept) => { - match futures_core::ready!(Pin::new(accept).poll(cx)) { - Ok(mut stream) => { - let result = Pin::new(&mut stream).poll_write(cx, buf); - pin.state = State::Streaming(stream); - result - } - Err(err) => Poll::Ready(Err(err)), - } + accept = tasks.next() => { + match accept.expect("FuturesUnordered stream should never end") { + Ok(Ok(io)) => SelectOutput::Io(io), + Ok(Err(e)) => SelectOutput::Err(e), + Err(e) => SelectOutput::Err(e.into()), } - State::Streaming(ref mut stream) => Pin::new(stream).poll_write(cx, buf), } } +} - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.state { - State::Handshaking(_) => Poll::Ready(Ok(())), - State::Streaming(ref mut stream) => Pin::new(stream).poll_flush(cx), - } +#[cfg(feature = "tls")] +enum SelectOutput { + Incoming(A), + Io(ServerIo), + Err(crate::Error), + Done, +} + +pub(crate) struct TcpIncoming { + inner: AddrIncoming, +} + +impl TcpIncoming { + pub(crate) fn new( + addr: SocketAddr, + nodelay: bool, + keepalive: Option, + ) -> Result { + let mut inner = AddrIncoming::bind(&addr)?; + inner.set_nodelay(nodelay); + inner.set_keepalive(keepalive); + Ok(TcpIncoming { inner }) } +} - fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.state { - State::Handshaking(_) => Poll::Ready(Ok(())), - State::Streaming(ref mut stream) => Pin::new(stream).poll_shutdown(cx), - } +impl Stream for TcpIncoming { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_accept(cx) } } diff --git a/tonic/src/transport/server/mod.rs b/tonic/src/transport/server/mod.rs index 569be8ace..3a2bb4cbe 100644 --- a/tonic/src/transport/server/mod.rs +++ b/tonic/src/transport/server/mod.rs @@ -16,7 +16,7 @@ use super::service::TlsAcceptor; use incoming::TcpIncoming; #[cfg(feature = "tls")] -pub(crate) use incoming::TlsStream; +pub(crate) use tokio_rustls::server::TlsStream; #[cfg(feature = "tls")] use crate::transport::Error; diff --git a/tonic/src/transport/service/tls.rs b/tonic/src/transport/service/tls.rs index f6e248752..dd3a6965d 100644 --- a/tonic/src/transport/service/tls.rs +++ b/tonic/src/transport/service/tls.rs @@ -162,14 +162,12 @@ impl TlsAcceptor { }) } - pub(crate) fn accept(&self, io: IO) -> TlsStream + pub(crate) async fn accept(&self, io: IO) -> Result, crate::Error> where IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static, { let acceptor = RustlsAcceptor::from(self.inner.clone()); - let accept = acceptor.accept(io); - - TlsStream::new(accept) + acceptor.accept(io).await.map_err(Into::into) } }