From f4b5622530d8d3973e895d7268f3dea401abc1fb Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Thu, 31 Aug 2023 11:34:52 +0100 Subject: [PATCH] address comments --- proxy/src/http/websocket.rs | 2 +- proxy/src/protocol2.rs | 121 +++++++++++++++++++++++++++++------- proxy/src/proxy.rs | 2 +- 3 files changed, 99 insertions(+), 26 deletions(-) diff --git a/proxy/src/http/websocket.rs b/proxy/src/http/websocket.rs index c94467f14b3b..fa66df0469cc 100644 --- a/proxy/src/http/websocket.rs +++ b/proxy/src/http/websocket.rs @@ -309,7 +309,7 @@ pub async fn task_main( let make_svc = hyper::service::make_service_fn( |stream: &tokio_rustls::server::TlsStream>| { let (io, tls) = stream.get_ref(); - let peer_addr = io.client_socket().unwrap_or(io.inner.remote_addr()); + let peer_addr = io.client_addr().unwrap_or(io.inner.remote_addr()); let sni_name = tls.server_name().map(|s| s.to_string()); let conn_pool = conn_pool.clone(); diff --git a/proxy/src/protocol2.rs b/proxy/src/protocol2.rs index 20e857ee80c7..1d8931be85fc 100644 --- a/proxy/src/protocol2.rs +++ b/proxy/src/protocol2.rs @@ -24,7 +24,7 @@ pin_project! { #[pin] pub inner: T, buf: BytesMut, - tlv_bytes: usize, + tlv_bytes: u16, state: ProxyParse, } } @@ -82,7 +82,7 @@ impl WithClientIp { } } - pub fn client_socket(&self) -> Option { + pub fn client_addr(&self) -> Option { match self.state { ProxyParse::Finished(socket) => Some(socket), _ => None, @@ -91,9 +91,20 @@ impl WithClientIp { } impl WithClientIp { - pub async fn wait_for_socket(&mut self) -> io::Result> { - let mut pin = Pin::new(self); - poll_fn(|cx| pin.as_mut().poll_client_ip(cx)).await + pub async fn wait_for_addr(&mut self) -> io::Result> { + match self.state { + ProxyParse::NotStarted => { + let mut pin = Pin::new(&mut *self); + let addr = poll_fn(|cx| pin.as_mut().poll_client_ip(cx)).await?; + match addr { + Some(addr) => self.state = ProxyParse::Finished(addr), + None => self.state = ProxyParse::None, + } + Ok(addr) + } + ProxyParse::Finished(addr) => Ok(Some(addr)), + ProxyParse::None => Ok(None), + } } } @@ -111,7 +122,7 @@ impl WithClientIp { ) -> Poll>> { // The binary header format starts with a constant 12 bytes block containing the protocol signature : // \x0D \x0A \x0D \x0A \x00 \x0D \x0A \x51 \x55 \x49 \x54 \x0A - while self.buf.len() <= 16 { + while self.buf.len() < 16 { let mut this = self.as_mut().project(); let bytes_read = pin!(this.inner.read_buf(this.buf)).poll(cx)?; @@ -134,7 +145,10 @@ impl WithClientIp { let version = vc >> 4; let command = vc & 0b1111; if version != 2 { - return Poll::Ready(Ok(None)); + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::Other, + "invalid proxy protocol version. expected version 2", + ))); } match command { // the connection was established on purpose by the proxy @@ -150,7 +164,12 @@ impl WithClientIp { 1 => {} // other values are unassigned and must not be emitted by senders. Receivers // must drop connections presenting unexpected values here. - _ => return Poll::Ready(Ok(None)), + _ => { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::Other, + "invalid proxy protocol command. expected local (0) or proxy (1)", + ))) + } }; // The 14th byte contains the transport protocol and address family. The highest 4 @@ -180,15 +199,21 @@ impl WithClientIp { // of bytes and must not assume zero is presented for LOCAL connections. When a // receiver accepts an incoming connection showing an UNSPEC address family or // protocol, it may or may not decide to log the address information if present. - let remaining_length = u16::from_be_bytes(self.buf[14..16].try_into().unwrap()) as usize; + let remaining_length = u16::from_be_bytes(self.buf[14..16].try_into().unwrap()); if remaining_length < address_length { - return Poll::Ready(Ok(None)); + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::Other, + "invalid proxy protocol length. not enough to fit requested IP addresses", + ))); } - while self.buf.len() < 16 + address_length { + while self.buf.len() < 16 + address_length as usize { let mut this = self.as_mut().project(); if ready!(pin!(this.inner.read_buf(this.buf)).poll(cx)?) == 0 { - return Poll::Ready(Ok(None)); + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "stream closed while waiting for proxy protocol addresses", + ))); } } @@ -204,7 +229,7 @@ impl WithClientIp { // - destination layer 3 address in network byte order // - source layer 4 address if any, in network byte order (port) // - destination layer 4 address if any, in network byte order (port) - let addresses = this.buf.split_to(address_length); + let addresses = this.buf.split_to(address_length as usize); let socket = match address_length { 12 => { let src_addr: [u8; 4] = addresses[0..4].try_into().unwrap(); @@ -220,9 +245,7 @@ impl WithClientIp { }; *this.tlv_bytes = remaining_length - address_length; - let discard = usize::min(*this.tlv_bytes, this.buf.len()); - *this.tlv_bytes -= discard; - this.buf.advance(discard); + self.as_mut().skip_tlv_inner(); Poll::Ready(Ok(socket)) } @@ -238,18 +261,29 @@ impl WithClientIp { } #[cold] - fn skip_tlv(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut this = self.project(); + fn skip_tlv(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.as_mut().project(); // we know that this.buf is empty debug_assert_eq!(this.buf.len(), 0); - let n = ready!(pin!(this.inner.read_buf(this.buf)).poll(cx)?); - let tlv_bytes_read = usize::min(n, *this.tlv_bytes); - *this.tlv_bytes -= tlv_bytes_read; - this.buf.advance(tlv_bytes_read); + this.buf.reserve((*this.tlv_bytes).clamp(0, 1024) as usize); + ready!(pin!(this.inner.read_buf(this.buf)).poll(cx)?); + self.skip_tlv_inner(); Poll::Ready(Ok(())) } + + fn skip_tlv_inner(self: Pin<&mut Self>) { + let tlv_bytes_read = match u16::try_from(self.buf.len()) { + // we read more than u16::MAX therefore we must have read the full tlv_bytes + Err(_) => self.tlv_bytes, + // we might not have read the full tlv bytes yet + Ok(n) => u16::min(n, self.tlv_bytes), + }; + let this = self.project(); + *this.tlv_bytes -= tlv_bytes_read; + this.buf.advance(tlv_bytes_read as usize); + } } impl AsyncRead for WithClientIp { @@ -282,6 +316,11 @@ impl AsyncRead for WithClientIp { let slice = this.buf.split_to(write).freeze(); buf.put_slice(&slice); + // reset the allocation so it can be freed + if this.buf.is_empty() { + *this.buf = BytesMut::new(); + } + Poll::Ready(Ok(())) } } @@ -316,7 +355,7 @@ mod tests { #[tokio::test] async fn test_ipv4() { let header = super::HEADER - // Proxy command, Inet << 4 | Stream + // Proxy command, IPV4 | TCP .chain([(2 << 4) | 1, (1 << 4) | 1].as_slice()) // 12 + 3 bytes .chain([0, 15].as_slice()) @@ -345,6 +384,40 @@ mod tests { ); } + #[tokio::test] + async fn test_ipv6() { + let header = super::HEADER + // Proxy command, IPV6 | UDP + .chain([(2 << 4) | 1, (2 << 4) | 2].as_slice()) + // 36 + 3 bytes + .chain([0, 39].as_slice()) + // src ip + .chain([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0].as_slice()) + // dst ip + .chain([0, 15, 1, 14, 2, 13, 3, 12, 4, 11, 5, 10, 6, 9, 7, 8].as_slice()) + // src port + .chain([1, 1].as_slice()) + // dst port + .chain([255, 255].as_slice()) + // TLV + .chain([1, 2, 3].as_slice()); + + let extra_data = [0x55; 256]; + + let mut read = pin!(WithClientIp::new(header.chain(extra_data.as_slice()))); + + let mut bytes = vec![]; + read.read_to_end(&mut bytes).await.unwrap(); + + assert_eq!(bytes, extra_data); + assert_eq!( + read.state, + ProxyParse::Finished( + ([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0], 257).into() + ) + ); + } + #[tokio::test] async fn test_invalid() { let data = [0x55; 256]; @@ -371,7 +444,7 @@ mod tests { #[tokio::test] async fn test_large_tlv() { - let tlv = [0x55; 512]; + let tlv = vec![0x55; 32768]; let len = (12 + tlv.len() as u16).to_be_bytes(); let header = super::HEADER diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 8afddd7f412d..66ce2e5fd05e 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -110,7 +110,7 @@ pub async fn task_main( info!("accepted postgres client connection"); let mut socket = WithClientIp::new(socket); - if let Some(ip) = socket.wait_for_socket().await? { + if let Some(ip) = socket.wait_for_addr().await? { tracing::Span::current().record("peer_addr", &tracing::field::display(ip)); }