Skip to content

Commit

Permalink
Grow noise buffers dynamically. (#1436)
Browse files Browse the repository at this point in the history
* Grow noise buffers dynamically.

Currently we allocate a buffer of 176 KiB for each noise state, i.e.
each connection. For connections which see only small data frames
this is wasteful. At the same time we limit the max. write buffer size
to 16 KiB to keep the total buffer size relatively small, which
results in smaller encrypted messages and also makes it less likely to
ever encounter the max. noise package size of 64 KiB in practice when
communicating with other nodes using the same implementation.

This PR repaces the static buffer allocation with a dynamic one. We
only reserve a small space for the authentication tag plus some extra
reserve and are able to buffer larger data frames before encrypting.

* Grow write buffer from offset.

As suggested by @mxinden, this prevents increasing the write buffer up
to MAX_WRITE_BUF_LEN.

Co-authored-by: Pierre Krieger <[email protected]>
  • Loading branch information
twittner and tomaka authored Feb 13, 2020
1 parent bbed28b commit 70d634d
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 106 deletions.
2 changes: 1 addition & 1 deletion protocols/noise/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ log = "0.4"
prost = "0.6.1"
rand = "0.7.2"
sha2 = "0.8.0"
static_assertions = "1"
x25519-dalek = "0.5"
zeroize = "1"

Expand All @@ -25,7 +26,6 @@ snow = { version = "0.6.1", features = ["ring-resolver"], default-features = fal
[target.'cfg(target_os = "unknown")'.dependencies]
snow = { version = "0.6.1", features = ["default-resolver"], default-features = false }


[dev-dependencies]
env_logger = "0.7.1"
libp2p-tcp = { version = "0.15.0", path = "../../transports/tcp" }
Expand Down
145 changes: 58 additions & 87 deletions protocols/noise/src/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,33 +26,17 @@ use futures::ready;
use futures::prelude::*;
use log::{debug, trace};
use snow;
use std::{fmt, io, pin::Pin, ops::DerefMut, task::{Context, Poll}};
use std::{cmp::min, fmt, io, pin::Pin, ops::DerefMut, task::{Context, Poll}};

/// Max. size of a noise package.
const MAX_NOISE_PKG_LEN: usize = 65535;
const MAX_WRITE_BUF_LEN: usize = 16384;
const TOTAL_BUFFER_LEN: usize = 2 * MAX_NOISE_PKG_LEN + 3 * MAX_WRITE_BUF_LEN;
/// Extra space given to the encryption buffer to hold key material.
const EXTRA_ENCRYPT_SPACE: usize = 1024;
/// Max. output buffer size before forcing a flush.
const MAX_WRITE_BUF_LEN: usize = MAX_NOISE_PKG_LEN - EXTRA_ENCRYPT_SPACE;

/// A single `Buffer` contains multiple non-overlapping byte buffers.
struct Buffer {
inner: Box<[u8; TOTAL_BUFFER_LEN]>
}

/// A mutable borrow of all byte buffers, backed by `Buffer`.
struct BufferBorrow<'a> {
read: &'a mut [u8],
read_crypto: &'a mut [u8],
write: &'a mut [u8],
write_crypto: &'a mut [u8]
}

impl Buffer {
/// Create a mutable borrow by splitting the buffer slice.
fn borrow_mut(&mut self) -> BufferBorrow<'_> {
let (r, w) = self.inner.split_at_mut(2 * MAX_NOISE_PKG_LEN);
let (read, read_crypto) = r.split_at_mut(MAX_NOISE_PKG_LEN);
let (write, write_crypto) = w.split_at_mut(MAX_WRITE_BUF_LEN);
BufferBorrow { read, read_crypto, write, write_crypto }
}
static_assertions::const_assert! {
MAX_WRITE_BUF_LEN + EXTRA_ENCRYPT_SPACE <= MAX_NOISE_PKG_LEN
}

/// A passthrough enum for the two kinds of state machines in `snow`
Expand Down Expand Up @@ -97,9 +81,12 @@ impl SnowState {
pub struct NoiseOutput<T> {
io: T,
session: SnowState,
buffer: Buffer,
read_state: ReadState,
write_state: WriteState
write_state: WriteState,
read_buffer: Vec<u8>,
write_buffer: Vec<u8>,
decrypt_buffer: Vec<u8>,
encrypt_buffer: Vec<u8>
}

impl<T> fmt::Debug for NoiseOutput<T> {
Expand All @@ -116,9 +103,12 @@ impl<T> NoiseOutput<T> {
NoiseOutput {
io,
session,
buffer: Buffer { inner: Box::new([0; TOTAL_BUFFER_LEN]) },
read_state: ReadState::Init,
write_state: WriteState::Init
write_state: WriteState::Init,
read_buffer: Vec::new(),
write_buffer: Vec::new(),
decrypt_buffer: Vec::new(),
encrypt_buffer: Vec::new()
}
}
}
Expand Down Expand Up @@ -159,15 +149,8 @@ enum WriteState {
}

impl<T: AsyncRead + Unpin> AsyncRead for NoiseOutput<T> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<Result<usize, std::io::Error>> {
fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll<io::Result<usize>> {
let mut this = self.deref_mut();

let buffer = this.buffer.borrow_mut();

loop {
trace!("read state: {:?}", this.read_state);
match this.read_state {
Expand All @@ -187,7 +170,6 @@ impl<T: AsyncRead + Unpin> AsyncRead for NoiseOutput<T> {
}
Poll::Pending => {
this.read_state = ReadState::ReadLen { buf, off };

return Poll::Pending;
}
};
Expand All @@ -197,30 +179,28 @@ impl<T: AsyncRead + Unpin> AsyncRead for NoiseOutput<T> {
this.read_state = ReadState::Init;
continue
}
this.read_buffer.resize(usize::from(n), 0u8);
this.read_state = ReadState::ReadData { len: usize::from(n), off: 0 }
}
ReadState::ReadData { len, ref mut off } => {
let n = match ready!(
Pin::new(&mut this.io).poll_read(cx, &mut buffer.read[*off ..len])
) {
Ok(n) => n,
Err(e) => return Poll::Ready(Err(e)),
let n = {
let f = Pin::new(&mut this.io).poll_read(cx, &mut this.read_buffer[*off .. len]);
match ready!(f) {
Ok(n) => n,
Err(e) => return Poll::Ready(Err(e)),
}
};

trace!("read: read {}/{} bytes", *off + n, len);
if n == 0 {
trace!("read: eof");
this.read_state = ReadState::Eof(Err(()));
return Poll::Ready(Err(io::ErrorKind::UnexpectedEof.into()))
}

*off += n;
if len == *off {
trace!("read: decrypting {} bytes", len);
if let Ok(n) = this.session.read_message(
&buffer.read[.. len],
buffer.read_crypto
){
this.decrypt_buffer.resize(len, 0u8);
if let Ok(n) = this.session.read_message(&this.read_buffer, &mut this.decrypt_buffer) {
trace!("read: payload len = {} bytes", n);
this.read_state = ReadState::CopyData { len: n, off: 0 }
} else {
Expand All @@ -231,8 +211,8 @@ impl<T: AsyncRead + Unpin> AsyncRead for NoiseOutput<T> {
}
}
ReadState::CopyData { len, ref mut off } => {
let n = std::cmp::min(len - *off, buf.len());
buf[.. n].copy_from_slice(&buffer.read_crypto[*off .. *off + n]);
let n = min(len - *off, buf.len());
buf[.. n].copy_from_slice(&this.decrypt_buffer[*off .. *off + n]);
trace!("read: copied {}/{} bytes", *off + n, len);
*off += n;
if len == *off {
Expand All @@ -255,29 +235,25 @@ impl<T: AsyncRead + Unpin> AsyncRead for NoiseOutput<T> {
}

impl<T: AsyncWrite + Unpin> AsyncWrite for NoiseOutput<T> {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>>{
fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
let mut this = self.deref_mut();

let buffer = this.buffer.borrow_mut();

loop {
trace!("write state: {:?}", this.write_state);
match this.write_state {
WriteState::Init => {
this.write_state = WriteState::BufferData { off: 0 }
}
WriteState::BufferData { ref mut off } => {
let n = std::cmp::min(MAX_WRITE_BUF_LEN - *off, buf.len());
buffer.write[*off .. *off + n].copy_from_slice(&buf[.. n]);
let n = min(MAX_WRITE_BUF_LEN, off.saturating_add(buf.len()));
this.write_buffer.resize(n, 0u8);
let n = min(MAX_WRITE_BUF_LEN - *off, buf.len());
this.write_buffer[*off .. *off + n].copy_from_slice(&buf[.. n]);
trace!("write: buffered {} bytes", *off + n);
*off += n;
if *off == MAX_WRITE_BUF_LEN {
trace!("write: encrypting {} bytes", *off);
match this.session.write_message(buffer.write, buffer.write_crypto) {
this.encrypt_buffer.resize(MAX_WRITE_BUF_LEN + EXTRA_ENCRYPT_SPACE, 0u8);
match this.session.write_message(&this.write_buffer, &mut this.encrypt_buffer) {
Ok(n) => {
trace!("write: cipher text len = {} bytes", n);
this.write_state = WriteState::WriteLen {
Expand Down Expand Up @@ -316,11 +292,12 @@ impl<T: AsyncWrite + Unpin> AsyncWrite for NoiseOutput<T> {
this.write_state = WriteState::WriteData { len, off: 0 }
}
WriteState::WriteData { len, ref mut off } => {
let n = match ready!(
Pin::new(&mut this.io).poll_write(cx, &buffer.write_crypto[*off .. len])
) {
Ok(n) => n,
Err(e) => return Poll::Ready(Err(e)),
let n = {
let f = Pin::new(&mut this.io).poll_write(cx, &this.encrypt_buffer[*off .. len]);
match ready!(f) {
Ok(n) => n,
Err(e) => return Poll::Ready(Err(e))
}
};
trace!("write: wrote {}/{} bytes", *off + n, len);
if n == 0 {
Expand All @@ -343,20 +320,17 @@ impl<T: AsyncWrite + Unpin> AsyncWrite for NoiseOutput<T> {
}
}

fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>
) -> Poll<Result<(), std::io::Error>> {
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
let mut this = self.deref_mut();

let buffer = this.buffer.borrow_mut();

loop {
match this.write_state {
WriteState::Init => return Pin::new(&mut this.io).poll_flush(cx),
WriteState::Init => {
return Pin::new(&mut this.io).poll_flush(cx)
}
WriteState::BufferData { off } => {
trace!("flush: encrypting {} bytes", off);
match this.session.write_message(&buffer.write[.. off], buffer.write_crypto) {
this.encrypt_buffer.resize(off + EXTRA_ENCRYPT_SPACE, 0u8);
match this.session.write_message(&this.write_buffer[.. off], &mut this.encrypt_buffer) {
Ok(n) => {
trace!("flush: cipher text len = {} bytes", n);
this.write_state = WriteState::WriteLen {
Expand Down Expand Up @@ -386,18 +360,18 @@ impl<T: AsyncWrite + Unpin> AsyncWrite for NoiseOutput<T> {
}
Poll::Pending => {
this.write_state = WriteState::WriteLen { len, buf, off };

return Poll::Pending
}
}
this.write_state = WriteState::WriteData { len, off: 0 }
}
WriteState::WriteData { len, ref mut off } => {
let n = match ready!(
Pin::new(&mut this.io).poll_write(cx, &buffer.write_crypto[*off .. len])
) {
Ok(n) => n,
Err(e) => return Poll::Ready(Err(e)),
let n = {
let f = Pin::new(&mut this.io).poll_write(cx, &this.encrypt_buffer[*off .. len]);
match ready!(f) {
Ok(n) => n,
Err(e) => return Poll::Ready(Err(e)),
}
};
trace!("flush: wrote {}/{} bytes", *off + n, len);
if n == 0 {
Expand All @@ -420,10 +394,7 @@ impl<T: AsyncWrite + Unpin> AsyncWrite for NoiseOutput<T> {
}
}

fn poll_close(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>>{
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>>{
ready!(self.as_mut().poll_flush(cx))?;
Pin::new(&mut self.io).poll_close(cx)
}
Expand All @@ -443,7 +414,7 @@ fn read_frame_len<R: AsyncRead + Unpin>(
cx: &mut Context<'_>,
buf: &mut [u8; 2],
off: &mut usize,
) -> Poll<Result<Option<u16>, std::io::Error>> {
) -> Poll<io::Result<Option<u16>>> {
loop {
match ready!(Pin::new(&mut io).poll_read(cx, &mut buf[*off ..])) {
Ok(n) => {
Expand Down Expand Up @@ -476,7 +447,7 @@ fn write_frame_len<W: AsyncWrite + Unpin>(
cx: &mut Context<'_>,
buf: &[u8; 2],
off: &mut usize,
) -> Poll<Result<bool, std::io::Error>> {
) -> Poll<io::Result<bool>> {
loop {
match ready!(Pin::new(&mut io).poll_write(cx, &buf[*off ..])) {
Ok(n) => {
Expand Down
Loading

0 comments on commit 70d634d

Please sign in to comment.