diff --git a/protocols/noise/Cargo.toml b/protocols/noise/Cargo.toml index 18fbb0052f5..5733a360410 100644 --- a/protocols/noise/Cargo.toml +++ b/protocols/noise/Cargo.toml @@ -16,6 +16,7 @@ log = "0.4" prost = "0.6.1" rand = "0.7.2" sha2 = "0.8.0" +static_assertions = "1" x25519-dalek = "0.5" zeroize = "1" @@ -25,7 +26,6 @@ snow = { version = "0.6.1", features = ["ring-resolver"], default-features = fal [target.'cfg(target_os = "unknown")'.dependencies] snow = { version = "0.6.1", features = ["default-resolver"], default-features = false } - [dev-dependencies] env_logger = "0.7.1" libp2p-tcp = { version = "0.15.0", path = "../../transports/tcp" } diff --git a/protocols/noise/src/io.rs b/protocols/noise/src/io.rs index 5318d6aec34..d3900a5c2ca 100644 --- a/protocols/noise/src/io.rs +++ b/protocols/noise/src/io.rs @@ -26,33 +26,17 @@ use futures::ready; use futures::prelude::*; use log::{debug, trace}; use snow; -use std::{fmt, io, pin::Pin, ops::DerefMut, task::{Context, Poll}}; +use std::{cmp::min, fmt, io, pin::Pin, ops::DerefMut, task::{Context, Poll}}; +/// Max. size of a noise package. const MAX_NOISE_PKG_LEN: usize = 65535; -const MAX_WRITE_BUF_LEN: usize = 16384; -const TOTAL_BUFFER_LEN: usize = 2 * MAX_NOISE_PKG_LEN + 3 * MAX_WRITE_BUF_LEN; +/// Extra space given to the encryption buffer to hold key material. +const EXTRA_ENCRYPT_SPACE: usize = 1024; +/// Max. output buffer size before forcing a flush. +const MAX_WRITE_BUF_LEN: usize = MAX_NOISE_PKG_LEN - EXTRA_ENCRYPT_SPACE; -/// A single `Buffer` contains multiple non-overlapping byte buffers. -struct Buffer { - inner: Box<[u8; TOTAL_BUFFER_LEN]> -} - -/// A mutable borrow of all byte buffers, backed by `Buffer`. -struct BufferBorrow<'a> { - read: &'a mut [u8], - read_crypto: &'a mut [u8], - write: &'a mut [u8], - write_crypto: &'a mut [u8] -} - -impl Buffer { - /// Create a mutable borrow by splitting the buffer slice. - fn borrow_mut(&mut self) -> BufferBorrow<'_> { - let (r, w) = self.inner.split_at_mut(2 * MAX_NOISE_PKG_LEN); - let (read, read_crypto) = r.split_at_mut(MAX_NOISE_PKG_LEN); - let (write, write_crypto) = w.split_at_mut(MAX_WRITE_BUF_LEN); - BufferBorrow { read, read_crypto, write, write_crypto } - } +static_assertions::const_assert! { + MAX_WRITE_BUF_LEN + EXTRA_ENCRYPT_SPACE <= MAX_NOISE_PKG_LEN } /// A passthrough enum for the two kinds of state machines in `snow` @@ -97,9 +81,12 @@ impl SnowState { pub struct NoiseOutput { io: T, session: SnowState, - buffer: Buffer, read_state: ReadState, - write_state: WriteState + write_state: WriteState, + read_buffer: Vec, + write_buffer: Vec, + decrypt_buffer: Vec, + encrypt_buffer: Vec } impl fmt::Debug for NoiseOutput { @@ -116,9 +103,12 @@ impl NoiseOutput { NoiseOutput { io, session, - buffer: Buffer { inner: Box::new([0; TOTAL_BUFFER_LEN]) }, read_state: ReadState::Init, - write_state: WriteState::Init + write_state: WriteState::Init, + read_buffer: Vec::new(), + write_buffer: Vec::new(), + decrypt_buffer: Vec::new(), + encrypt_buffer: Vec::new() } } } @@ -159,15 +149,8 @@ enum WriteState { } impl AsyncRead for NoiseOutput { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { + fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll> { let mut this = self.deref_mut(); - - let buffer = this.buffer.borrow_mut(); - loop { trace!("read state: {:?}", this.read_state); match this.read_state { @@ -187,7 +170,6 @@ impl AsyncRead for NoiseOutput { } Poll::Pending => { this.read_state = ReadState::ReadLen { buf, off }; - return Poll::Pending; } }; @@ -197,30 +179,28 @@ impl AsyncRead for NoiseOutput { this.read_state = ReadState::Init; continue } + this.read_buffer.resize(usize::from(n), 0u8); this.read_state = ReadState::ReadData { len: usize::from(n), off: 0 } } ReadState::ReadData { len, ref mut off } => { - let n = match ready!( - Pin::new(&mut this.io).poll_read(cx, &mut buffer.read[*off ..len]) - ) { - Ok(n) => n, - Err(e) => return Poll::Ready(Err(e)), + let n = { + let f = Pin::new(&mut this.io).poll_read(cx, &mut this.read_buffer[*off .. len]); + match ready!(f) { + Ok(n) => n, + Err(e) => return Poll::Ready(Err(e)), + } }; - trace!("read: read {}/{} bytes", *off + n, len); if n == 0 { trace!("read: eof"); this.read_state = ReadState::Eof(Err(())); return Poll::Ready(Err(io::ErrorKind::UnexpectedEof.into())) } - *off += n; if len == *off { trace!("read: decrypting {} bytes", len); - if let Ok(n) = this.session.read_message( - &buffer.read[.. len], - buffer.read_crypto - ){ + this.decrypt_buffer.resize(len, 0u8); + if let Ok(n) = this.session.read_message(&this.read_buffer, &mut this.decrypt_buffer) { trace!("read: payload len = {} bytes", n); this.read_state = ReadState::CopyData { len: n, off: 0 } } else { @@ -231,8 +211,8 @@ impl AsyncRead for NoiseOutput { } } ReadState::CopyData { len, ref mut off } => { - let n = std::cmp::min(len - *off, buf.len()); - buf[.. n].copy_from_slice(&buffer.read_crypto[*off .. *off + n]); + let n = min(len - *off, buf.len()); + buf[.. n].copy_from_slice(&this.decrypt_buffer[*off .. *off + n]); trace!("read: copied {}/{} bytes", *off + n, len); *off += n; if len == *off { @@ -255,15 +235,8 @@ impl AsyncRead for NoiseOutput { } impl AsyncWrite for NoiseOutput { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll>{ + fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { let mut this = self.deref_mut(); - - let buffer = this.buffer.borrow_mut(); - loop { trace!("write state: {:?}", this.write_state); match this.write_state { @@ -271,13 +244,16 @@ impl AsyncWrite for NoiseOutput { this.write_state = WriteState::BufferData { off: 0 } } WriteState::BufferData { ref mut off } => { - let n = std::cmp::min(MAX_WRITE_BUF_LEN - *off, buf.len()); - buffer.write[*off .. *off + n].copy_from_slice(&buf[.. n]); + let n = min(MAX_WRITE_BUF_LEN, off.saturating_add(buf.len())); + this.write_buffer.resize(n, 0u8); + let n = min(MAX_WRITE_BUF_LEN - *off, buf.len()); + this.write_buffer[*off .. *off + n].copy_from_slice(&buf[.. n]); trace!("write: buffered {} bytes", *off + n); *off += n; if *off == MAX_WRITE_BUF_LEN { trace!("write: encrypting {} bytes", *off); - match this.session.write_message(buffer.write, buffer.write_crypto) { + this.encrypt_buffer.resize(MAX_WRITE_BUF_LEN + EXTRA_ENCRYPT_SPACE, 0u8); + match this.session.write_message(&this.write_buffer, &mut this.encrypt_buffer) { Ok(n) => { trace!("write: cipher text len = {} bytes", n); this.write_state = WriteState::WriteLen { @@ -316,11 +292,12 @@ impl AsyncWrite for NoiseOutput { this.write_state = WriteState::WriteData { len, off: 0 } } WriteState::WriteData { len, ref mut off } => { - let n = match ready!( - Pin::new(&mut this.io).poll_write(cx, &buffer.write_crypto[*off .. len]) - ) { - Ok(n) => n, - Err(e) => return Poll::Ready(Err(e)), + let n = { + let f = Pin::new(&mut this.io).poll_write(cx, &this.encrypt_buffer[*off .. len]); + match ready!(f) { + Ok(n) => n, + Err(e) => return Poll::Ready(Err(e)) + } }; trace!("write: wrote {}/{} bytes", *off + n, len); if n == 0 { @@ -343,20 +320,17 @@ impl AsyncWrite for NoiseOutput { } } - fn poll_flush( - mut self: Pin<&mut Self>, - cx: &mut Context<'_> - ) -> Poll> { + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { let mut this = self.deref_mut(); - - let buffer = this.buffer.borrow_mut(); - loop { match this.write_state { - WriteState::Init => return Pin::new(&mut this.io).poll_flush(cx), + WriteState::Init => { + return Pin::new(&mut this.io).poll_flush(cx) + } WriteState::BufferData { off } => { trace!("flush: encrypting {} bytes", off); - match this.session.write_message(&buffer.write[.. off], buffer.write_crypto) { + this.encrypt_buffer.resize(off + EXTRA_ENCRYPT_SPACE, 0u8); + match this.session.write_message(&this.write_buffer[.. off], &mut this.encrypt_buffer) { Ok(n) => { trace!("flush: cipher text len = {} bytes", n); this.write_state = WriteState::WriteLen { @@ -386,18 +360,18 @@ impl AsyncWrite for NoiseOutput { } Poll::Pending => { this.write_state = WriteState::WriteLen { len, buf, off }; - return Poll::Pending } } this.write_state = WriteState::WriteData { len, off: 0 } } WriteState::WriteData { len, ref mut off } => { - let n = match ready!( - Pin::new(&mut this.io).poll_write(cx, &buffer.write_crypto[*off .. len]) - ) { - Ok(n) => n, - Err(e) => return Poll::Ready(Err(e)), + let n = { + let f = Pin::new(&mut this.io).poll_write(cx, &this.encrypt_buffer[*off .. len]); + match ready!(f) { + Ok(n) => n, + Err(e) => return Poll::Ready(Err(e)), + } }; trace!("flush: wrote {}/{} bytes", *off + n, len); if n == 0 { @@ -420,10 +394,7 @@ impl AsyncWrite for NoiseOutput { } } - fn poll_close( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll>{ + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll>{ ready!(self.as_mut().poll_flush(cx))?; Pin::new(&mut self.io).poll_close(cx) } @@ -443,7 +414,7 @@ fn read_frame_len( cx: &mut Context<'_>, buf: &mut [u8; 2], off: &mut usize, -) -> Poll, std::io::Error>> { +) -> Poll>> { loop { match ready!(Pin::new(&mut io).poll_read(cx, &mut buf[*off ..])) { Ok(n) => { @@ -476,7 +447,7 @@ fn write_frame_len( cx: &mut Context<'_>, buf: &[u8; 2], off: &mut usize, -) -> Poll> { +) -> Poll> { loop { match ready!(Pin::new(&mut io).poll_write(cx, &buf[*off ..])) { Ok(n) => { diff --git a/protocols/noise/tests/smoke.rs b/protocols/noise/tests/smoke.rs index 5fa579c2967..1ac04491a68 100644 --- a/protocols/noise/tests/smoke.rs +++ b/protocols/noise/tests/smoke.rs @@ -26,6 +26,7 @@ use libp2p_noise::{Keypair, X25519, NoiseConfig, RemoteIdentity, NoiseError, Noi use libp2p_tcp::{TcpConfig, TcpTransStream}; use log::info; use quickcheck::QuickCheck; +use std::{convert::TryInto, io}; #[allow(dead_code)] fn core_upgrade_compat() { @@ -40,7 +41,8 @@ fn core_upgrade_compat() { #[test] fn xx() { let _ = env_logger::try_init(); - fn prop(message: Vec) -> bool { + fn prop(mut messages: Vec) -> bool { + messages.truncate(5); let server_id = identity::Keypair::generate_ed25519(); let client_id = identity::Keypair::generate_ed25519(); @@ -61,16 +63,17 @@ fn xx() { }) .and_then(move |out, _| expect_identity(out, &server_id_public)); - run(server_transport, client_transport, message); + run(server_transport, client_transport, messages); true } - QuickCheck::new().max_tests(30).quickcheck(prop as fn(Vec) -> bool) + QuickCheck::new().max_tests(30).quickcheck(prop as fn(Vec) -> bool) } #[test] fn ix() { let _ = env_logger::try_init(); - fn prop(message: Vec) -> bool { + fn prop(mut messages: Vec) -> bool { + messages.truncate(5); let server_id = identity::Keypair::generate_ed25519(); let client_id = identity::Keypair::generate_ed25519(); @@ -91,16 +94,17 @@ fn ix() { }) .and_then(move |out, _| expect_identity(out, &server_id_public)); - run(server_transport, client_transport, message); + run(server_transport, client_transport, messages); true } - QuickCheck::new().max_tests(30).quickcheck(prop as fn(Vec) -> bool) + QuickCheck::new().max_tests(30).quickcheck(prop as fn(Vec) -> bool) } #[test] fn ik_xx() { let _ = env_logger::try_init(); - fn prop(message: Vec) -> bool { + fn prop(mut messages: Vec) -> bool { + messages.truncate(5); let server_id = identity::Keypair::generate_ed25519(); let server_id_public = server_id.public(); @@ -134,15 +138,15 @@ fn ik_xx() { }) .and_then(move |out, _| expect_identity(out, &server_id_public2)); - run(server_transport, client_transport, message); + run(server_transport, client_transport, messages); true } - QuickCheck::new().max_tests(30).quickcheck(prop as fn(Vec) -> bool) + QuickCheck::new().max_tests(30).quickcheck(prop as fn(Vec) -> bool) } type Output = (RemoteIdentity, NoiseOutput>); -fn run(server_transport: T, client_transport: U, message1: Vec) +fn run(server_transport: T, client_transport: U, messages: I) where T: Transport, T::Dial: Send + 'static, @@ -152,10 +156,9 @@ where U::Dial: Send + 'static, U::Listener: Send + 'static, U::ListenerUpgrade: Send + 'static, + I: IntoIterator + Clone { futures::executor::block_on(async { - let mut message2 = message1.clone(); - let mut server: T::Listener = server_transport .listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()) .unwrap(); @@ -167,6 +170,7 @@ where .into_new_address() .expect("listen address"); + let outbound_msgs = messages.clone(); let client_fut = async { let mut client_session = client_transport.dial(server_address.clone()) .unwrap() @@ -174,7 +178,11 @@ where .map(|(_, session)| session) .expect("no error"); - client_session.write_all(&mut message2).await.expect("no error"); + for m in outbound_msgs { + let n = (m.0.len() as u64).to_be_bytes(); + client_session.write_all(&n[..]).await.expect("len written"); + client_session.write_all(&m.0).await.expect("no error") + } client_session.flush().await.expect("no error"); }; @@ -190,11 +198,20 @@ where .map(|(_, session)| session) .expect("no error"); - let mut server_buffer = vec![]; - info!("server: reading message"); - server_session.read_to_end(&mut server_buffer).await.expect("no error"); - - assert_eq!(server_buffer, message1); + for m in messages { + let len = { + let mut n = [0; 8]; + match server_session.read_exact(&mut n).await { + Ok(()) => u64::from_be_bytes(n), + Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => 0, + Err(e) => panic!("error reading len: {}", e) + } + }; + info!("server: reading message ({} bytes)", len); + let mut server_buffer = vec![0; len.try_into().unwrap()]; + server_session.read_exact(&mut server_buffer).await.expect("no error"); + assert_eq!(server_buffer, m.0) + } }; futures::future::join(server_fut, client_fut).await; @@ -209,3 +226,15 @@ fn expect_identity(output: Output, pk: &identity::PublicKey) _ => panic!("Unexpected remote identity") } } + +#[derive(Debug, Clone, PartialEq, Eq)] +struct Message(Vec); + +impl quickcheck::Arbitrary for Message { + fn arbitrary(g: &mut G) -> Self { + let s = 1 + g.next_u32() % (128 * 1024); + let mut v = vec![0; s.try_into().unwrap()]; + g.fill_bytes(&mut v); + Message(v) + } +}