Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
conradludgate committed Aug 31, 2023
1 parent 6f1a18b commit f4b5622
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 26 deletions.
2 changes: 1 addition & 1 deletion proxy/src/http/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ pub async fn task_main(
let make_svc = hyper::service::make_service_fn(
|stream: &tokio_rustls::server::TlsStream<WithClientIp<AddrStream>>| {
let (io, tls) = stream.get_ref();
let peer_addr = io.client_socket().unwrap_or(io.inner.remote_addr());
let peer_addr = io.client_addr().unwrap_or(io.inner.remote_addr());
let sni_name = tls.server_name().map(|s| s.to_string());
let conn_pool = conn_pool.clone();

Expand Down
121 changes: 97 additions & 24 deletions proxy/src/protocol2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pin_project! {
#[pin]
pub inner: T,
buf: BytesMut,
tlv_bytes: usize,
tlv_bytes: u16,
state: ProxyParse,
}
}
Expand Down Expand Up @@ -82,7 +82,7 @@ impl<T> WithClientIp<T> {
}
}

pub fn client_socket(&self) -> Option<SocketAddr> {
pub fn client_addr(&self) -> Option<SocketAddr> {
match self.state {
ProxyParse::Finished(socket) => Some(socket),
_ => None,
Expand All @@ -91,9 +91,20 @@ impl<T> WithClientIp<T> {
}

impl<T: AsyncRead + Unpin> WithClientIp<T> {
pub async fn wait_for_socket(&mut self) -> io::Result<Option<SocketAddr>> {
let mut pin = Pin::new(self);
poll_fn(|cx| pin.as_mut().poll_client_ip(cx)).await
pub async fn wait_for_addr(&mut self) -> io::Result<Option<SocketAddr>> {
match self.state {
ProxyParse::NotStarted => {
let mut pin = Pin::new(&mut *self);
let addr = poll_fn(|cx| pin.as_mut().poll_client_ip(cx)).await?;
match addr {
Some(addr) => self.state = ProxyParse::Finished(addr),
None => self.state = ProxyParse::None,
}
Ok(addr)
}
ProxyParse::Finished(addr) => Ok(Some(addr)),
ProxyParse::None => Ok(None),
}
}
}

Expand All @@ -111,7 +122,7 @@ impl<T: AsyncRead> WithClientIp<T> {
) -> Poll<io::Result<Option<SocketAddr>>> {
// The binary header format starts with a constant 12 bytes block containing the protocol signature :
// \x0D \x0A \x0D \x0A \x00 \x0D \x0A \x51 \x55 \x49 \x54 \x0A
while self.buf.len() <= 16 {
while self.buf.len() < 16 {
let mut this = self.as_mut().project();
let bytes_read = pin!(this.inner.read_buf(this.buf)).poll(cx)?;

Expand All @@ -134,7 +145,10 @@ impl<T: AsyncRead> WithClientIp<T> {
let version = vc >> 4;
let command = vc & 0b1111;
if version != 2 {
return Poll::Ready(Ok(None));
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::Other,
"invalid proxy protocol version. expected version 2",
)));
}
match command {
// the connection was established on purpose by the proxy
Expand All @@ -150,7 +164,12 @@ impl<T: AsyncRead> WithClientIp<T> {
1 => {}
// other values are unassigned and must not be emitted by senders. Receivers
// must drop connections presenting unexpected values here.
_ => return Poll::Ready(Ok(None)),
_ => {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::Other,
"invalid proxy protocol command. expected local (0) or proxy (1)",
)))
}
};

// The 14th byte contains the transport protocol and address family. The highest 4
Expand Down Expand Up @@ -180,15 +199,21 @@ impl<T: AsyncRead> WithClientIp<T> {
// of bytes and must not assume zero is presented for LOCAL connections. When a
// receiver accepts an incoming connection showing an UNSPEC address family or
// protocol, it may or may not decide to log the address information if present.
let remaining_length = u16::from_be_bytes(self.buf[14..16].try_into().unwrap()) as usize;
let remaining_length = u16::from_be_bytes(self.buf[14..16].try_into().unwrap());
if remaining_length < address_length {
return Poll::Ready(Ok(None));
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::Other,
"invalid proxy protocol length. not enough to fit requested IP addresses",
)));
}

while self.buf.len() < 16 + address_length {
while self.buf.len() < 16 + address_length as usize {
let mut this = self.as_mut().project();
if ready!(pin!(this.inner.read_buf(this.buf)).poll(cx)?) == 0 {
return Poll::Ready(Ok(None));
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"stream closed while waiting for proxy protocol addresses",
)));
}
}

Expand All @@ -204,7 +229,7 @@ impl<T: AsyncRead> WithClientIp<T> {
// - destination layer 3 address in network byte order
// - source layer 4 address if any, in network byte order (port)
// - destination layer 4 address if any, in network byte order (port)
let addresses = this.buf.split_to(address_length);
let addresses = this.buf.split_to(address_length as usize);
let socket = match address_length {
12 => {
let src_addr: [u8; 4] = addresses[0..4].try_into().unwrap();
Expand All @@ -220,9 +245,7 @@ impl<T: AsyncRead> WithClientIp<T> {
};

*this.tlv_bytes = remaining_length - address_length;
let discard = usize::min(*this.tlv_bytes, this.buf.len());
*this.tlv_bytes -= discard;
this.buf.advance(discard);
self.as_mut().skip_tlv_inner();

Poll::Ready(Ok(socket))
}
Expand All @@ -238,18 +261,29 @@ impl<T: AsyncRead> WithClientIp<T> {
}

#[cold]
fn skip_tlv(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let mut this = self.project();
fn skip_tlv(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let mut this = self.as_mut().project();
// we know that this.buf is empty
debug_assert_eq!(this.buf.len(), 0);

let n = ready!(pin!(this.inner.read_buf(this.buf)).poll(cx)?);
let tlv_bytes_read = usize::min(n, *this.tlv_bytes);
*this.tlv_bytes -= tlv_bytes_read;
this.buf.advance(tlv_bytes_read);
this.buf.reserve((*this.tlv_bytes).clamp(0, 1024) as usize);
ready!(pin!(this.inner.read_buf(this.buf)).poll(cx)?);
self.skip_tlv_inner();

Poll::Ready(Ok(()))
}

fn skip_tlv_inner(self: Pin<&mut Self>) {
let tlv_bytes_read = match u16::try_from(self.buf.len()) {
// we read more than u16::MAX therefore we must have read the full tlv_bytes
Err(_) => self.tlv_bytes,
// we might not have read the full tlv bytes yet
Ok(n) => u16::min(n, self.tlv_bytes),
};
let this = self.project();
*this.tlv_bytes -= tlv_bytes_read;
this.buf.advance(tlv_bytes_read as usize);
}
}

impl<T: AsyncRead> AsyncRead for WithClientIp<T> {
Expand Down Expand Up @@ -282,6 +316,11 @@ impl<T: AsyncRead> AsyncRead for WithClientIp<T> {
let slice = this.buf.split_to(write).freeze();
buf.put_slice(&slice);

// reset the allocation so it can be freed
if this.buf.is_empty() {
*this.buf = BytesMut::new();
}

Poll::Ready(Ok(()))
}
}
Expand Down Expand Up @@ -316,7 +355,7 @@ mod tests {
#[tokio::test]
async fn test_ipv4() {
let header = super::HEADER
// Proxy command, Inet << 4 | Stream
// Proxy command, IPV4 | TCP
.chain([(2 << 4) | 1, (1 << 4) | 1].as_slice())
// 12 + 3 bytes
.chain([0, 15].as_slice())
Expand Down Expand Up @@ -345,6 +384,40 @@ mod tests {
);
}

#[tokio::test]
async fn test_ipv6() {
let header = super::HEADER
// Proxy command, IPV6 | UDP
.chain([(2 << 4) | 1, (2 << 4) | 2].as_slice())
// 36 + 3 bytes
.chain([0, 39].as_slice())
// src ip
.chain([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0].as_slice())
// dst ip
.chain([0, 15, 1, 14, 2, 13, 3, 12, 4, 11, 5, 10, 6, 9, 7, 8].as_slice())
// src port
.chain([1, 1].as_slice())
// dst port
.chain([255, 255].as_slice())
// TLV
.chain([1, 2, 3].as_slice());

let extra_data = [0x55; 256];

let mut read = pin!(WithClientIp::new(header.chain(extra_data.as_slice())));

let mut bytes = vec![];
read.read_to_end(&mut bytes).await.unwrap();

assert_eq!(bytes, extra_data);
assert_eq!(
read.state,
ProxyParse::Finished(
([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0], 257).into()
)
);
}

#[tokio::test]
async fn test_invalid() {
let data = [0x55; 256];
Expand All @@ -371,7 +444,7 @@ mod tests {

#[tokio::test]
async fn test_large_tlv() {
let tlv = [0x55; 512];
let tlv = vec![0x55; 32768];
let len = (12 + tlv.len() as u16).to_be_bytes();

let header = super::HEADER
Expand Down
2 changes: 1 addition & 1 deletion proxy/src/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ pub async fn task_main(
info!("accepted postgres client connection");

let mut socket = WithClientIp::new(socket);
if let Some(ip) = socket.wait_for_socket().await? {
if let Some(ip) = socket.wait_for_addr().await? {
tracing::Span::current().record("peer_addr", &tracing::field::display(ip));
}

Expand Down

0 comments on commit f4b5622

Please sign in to comment.