Skip to content

Commit

Permalink
fix: add gracefull shutdown ensure all data received from the quinn s…
Browse files Browse the repository at this point in the history
…tack
  • Loading branch information
fabian1409 authored and 0xThemis committed Oct 15, 2024
1 parent b4757d2 commit a9cbcbf
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 73 deletions.
2 changes: 0 additions & 2 deletions mpc-core/src/protocols/bridges/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ impl RepToShamirNetwork<ShamirMpcNet> for Rep3MpcNet {
fn to_shamir_net(self) -> ShamirMpcNet {
let Self {
id,
runtime,
net_handler,
chan_next,
chan_prev,
Expand All @@ -32,7 +31,6 @@ impl RepToShamirNetwork<ShamirMpcNet> for Rep3MpcNet {
ShamirMpcNet {
id: id.into(),
num_parties: 3,
runtime,
net_handler,
channels,
}
Expand Down
44 changes: 21 additions & 23 deletions mpc-core/src/protocols/rep3/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ use ark_ff::PrimeField;
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
use bytes::{Bytes, BytesMut};
use eyre::{bail, eyre, Report};
use mpc_net::{channel::ChannelHandle, config::NetworkConfig, MpcNetworkHandler};
use mpc_net::{
channel::ChannelHandle, config::NetworkConfig, MpcNetworkHandler, MpcNetworkHandlerWrapper,
};

use super::{
id::PartyID,
Expand Down Expand Up @@ -214,8 +216,7 @@ pub trait Rep3Network: Send {
#[derive(Debug)]
pub struct Rep3MpcNet {
pub(crate) id: PartyID,
pub(crate) runtime: Arc<tokio::runtime::Runtime>,
pub(crate) net_handler: Arc<MpcNetworkHandler>,
pub(crate) net_handler: Arc<MpcNetworkHandlerWrapper>,
pub(crate) chan_next: ChannelHandle<Bytes, BytesMut>,
pub(crate) chan_prev: ChannelHandle<Bytes, BytesMut>,
}
Expand Down Expand Up @@ -249,28 +250,27 @@ impl Rep3MpcNet {
})?;
Ok(Self {
id,
runtime: Arc::new(runtime),
net_handler: Arc::new(net_handler),
net_handler: Arc::new(MpcNetworkHandlerWrapper::new(runtime, net_handler)),
chan_next,
chan_prev,
})
}

/// Shuts down the network interface.
pub fn shutdown(self) {
let Self {
id: _,
runtime,
net_handler,
chan_next,
chan_prev,
} = self;
drop(chan_next);
drop(chan_prev);
if let Some(net_handler) = Arc::into_inner(net_handler) {
runtime.block_on(net_handler.shutdown());
}
}
// pub fn shutdown(self) {
// let Self {
// id: _,
// runtime,
// net_handler,
// chan_next,
// chan_prev,
// } = self;
// drop(chan_next);
// drop(chan_prev);
// if let Some(net_handler) = Arc::into_inner(net_handler) {
// runtime.block_on(net_handler.shutdown());
// }
// }

/// Sends bytes over the network to the target party.
pub fn send_bytes(&mut self, target: PartyID, data: Bytes) -> std::io::Result<()> {
Expand Down Expand Up @@ -355,9 +355,8 @@ impl Rep3Network for Rep3MpcNet {
fn fork(&mut self) -> std::io::Result<Self> {
let id = self.id;
let net_handler = Arc::clone(&self.net_handler);
let runtime = Arc::clone(&self.runtime);
let (chan_next, chan_prev) = runtime.block_on(async {
let mut channels = net_handler.get_byte_channels().await?;
let (chan_next, chan_prev) = net_handler.runtime.block_on(async {
let mut channels = net_handler.inner.get_byte_channels().await?;

let chan_next = channels
.remove(&id.next_id().into())
Expand All @@ -376,7 +375,6 @@ impl Rep3Network for Rep3MpcNet {

Ok(Self {
id,
runtime,
net_handler,
chan_next,
chan_prev,
Expand Down
49 changes: 23 additions & 26 deletions mpc-core/src/protocols/shamir/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
use bytes::{Bytes, BytesMut};
use eyre::{bail, eyre, Report};
use mpc_net::{channel::ChannelHandle, config::NetworkConfig, MpcNetworkHandler};
use mpc_net::{
channel::ChannelHandle, config::NetworkConfig, MpcNetworkHandler, MpcNetworkHandlerWrapper,
};
use std::{collections::HashMap, sync::Arc};

/// This trait defines the network interface for the Shamir protocol.
Expand Down Expand Up @@ -75,8 +77,7 @@ pub trait ShamirNetwork: Send {
pub struct ShamirMpcNet {
pub(crate) id: usize, // 0 <= id < num_parties
pub(crate) num_parties: usize,
pub(crate) runtime: Arc<tokio::runtime::Runtime>,
pub(crate) net_handler: Arc<MpcNetworkHandler>,
pub(crate) net_handler: Arc<MpcNetworkHandlerWrapper>,
pub(crate) channels: HashMap<usize, ChannelHandle<Bytes, BytesMut>>,
}

Expand Down Expand Up @@ -120,30 +121,28 @@ impl ShamirMpcNet {
Ok(Self {
id,
num_parties,
runtime: Arc::new(runtime),
net_handler: Arc::new(net_handler),
net_handler: Arc::new(MpcNetworkHandlerWrapper::new(runtime, net_handler)),
channels,
})
}

/// Shuts down the network interface.
pub fn shutdown(self) {
let Self {
id: _,
num_parties: _,
runtime,
net_handler,
channels,
} = self;
for chan in channels.into_iter() {
drop(chan);
}
if let Some(net_handler) = Arc::into_inner(net_handler) {
runtime.block_on(async {
let _ = net_handler.shutdown().await;
});
}
}
// pub fn shutdown(self) {
// let Self {
// id: _,
// num_parties: _,
// net_handler,
// channels,
// } = self;
// for chan in channels.into_iter() {
// drop(chan);
// }
// if let Some(net_handler) = Arc::into_inner(net_handler) {
// runtime.block_on(async {
// net_handler.shutdown().await;
// });
// }
// }

/// Sends bytes over the network to the target party.
pub fn send_bytes(&mut self, target: usize, data: Bytes) -> std::io::Result<()> {
Expand Down Expand Up @@ -284,9 +283,8 @@ impl ShamirNetwork for ShamirMpcNet {
let id = self.id;
let num_parties = self.num_parties;
let net_handler = Arc::clone(&self.net_handler);
let runtime = Arc::clone(&self.runtime);
let channels = runtime.block_on(async {
let mut channels = net_handler.get_byte_channels().await?;
let channels = net_handler.runtime.block_on(async {
let mut channels = net_handler.inner.get_byte_channels().await?;

let mut channels_ = HashMap::with_capacity(num_parties - 1);

Expand All @@ -307,7 +305,6 @@ impl ShamirNetwork for ShamirMpcNet {
Ok(Self {
id,
num_parties,
runtime,
net_handler,
channels,
})
Expand Down
39 changes: 22 additions & 17 deletions mpc-net/src/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,26 +188,30 @@ where
buffer: &mut VecDeque<Result<MRecv, io::Error>>,
read_recv: &mut mpsc::Receiver<ReadJob<MRecv>>,
frame_reader: &mut FramedRead<R, C>,
) where
) -> bool
where
C: 'static,
R: AsyncReadExt + Unpin + 'static,
FramedRead<R, C>: Stream<Item = Result<MRecv, io::Error>> + Send,
{
//we did not get a job so far so just put into buffer
//also if we get None we maybe need to close everything but for now just this
let read_result = match frame {
None => Err(io::Error::new(io::ErrorKind::BrokenPipe, "closed pipe")),
Some(res) => res,
};
if buffer.len() >= READ_BUFFER_SIZE {
//wait for a read job as buffer is full
if let Some(read_job) = read_recv.recv().await {
Self::handle_read_job(read_job, buffer, frame_reader).await;
} else {
tracing::warn!("still have frames in buffer but channel dropped?");
// we did not get a job so far so just put into buffer
// if the frame is None (either because the other party is done, or the connection was closed) we return false and stop read task
if let Some(read_result) = frame {
if buffer.len() >= READ_BUFFER_SIZE {
//wait for a read job as buffer is full
if let Some(read_job) = read_recv.recv().await {
Self::handle_read_job(read_job, buffer, frame_reader).await;
} else {
tracing::warn!(
"[handel_read_frame] still have frames in buffer but channel dropped?"
);
}
}
buffer.push_back(read_result);
true
} else {
false
}
buffer.push_back(read_result);
}

/// Create a new [`ChannelHandle`] from a [`Channel`]. This spawns a new tokio task that handles the read and write jobs so they can happen concurrently.
Expand Down Expand Up @@ -239,9 +243,10 @@ where
//futures::stream::StreamExt::next on any Stream is cancellation safe but also
//when using quinn? Should be...
frame = read.next() => {
//if this method returns true we break
//this happens when the read job channel dropped
Self::handle_read_frame(frame, &mut buffer, &mut read_recv, &mut read).await
//if this method returns false we break
if !Self::handle_read_frame(frame, &mut buffer, &mut read_recv, &mut read).await {
break;
}
}
}
}
Expand Down
60 changes: 55 additions & 5 deletions mpc-net/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,40 @@ use quinn::{
VarInt,
};
use serde::{de::DeserializeOwned, Serialize};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
runtime::Runtime,
};
use tokio_util::codec::{Decoder, Encoder, LengthDelimitedCodec};

pub mod channel;
pub mod codecs;
pub mod config;

/// A warapper for a runtime and a network handler for MPC protocols.
/// Ensures a gracefull shutdown on drop
#[derive(Debug)]
pub struct MpcNetworkHandlerWrapper {
/// The runtime used by the network handler
pub runtime: Runtime,
/// The wrapped network handler
pub inner: MpcNetworkHandler,
}

impl MpcNetworkHandlerWrapper {
/// Create a new wrapper
pub fn new(runtime: Runtime, inner: MpcNetworkHandler) -> Self {
Self { runtime, inner }
}
}

impl Drop for MpcNetworkHandlerWrapper {
fn drop(&mut self) {
// ignore errors in drop
let _ = self.runtime.block_on(self.inner.shutdown());
}
}

/// A network handler for MPC protocols.
#[derive(Debug)]
pub struct MpcNetworkHandler {
Expand Down Expand Up @@ -246,13 +273,36 @@ impl MpcNetworkHandler {
}

/// Shutdown all connections, and call [`quinn::Endpoint::wait_idle`] on all of them
pub async fn shutdown(self) {
for conn in self.connections.into_values() {
conn.close(0u32.into(), b"");
pub async fn shutdown(&self) -> std::io::Result<()> {
tracing::debug!(
"party {} shutting down, conns = {:?}",
self.my_id,
self.connections.keys()
);

for (id, conn) in self.connections.iter() {
if self.my_id < *id {
let mut send = conn.open_uni().await?;
send.write_all(b"done").await?;
} else {
let mut recv = conn.accept_uni().await?;
let mut buffer = vec![0u8; b"done".len()];
recv.read_exact(&mut buffer).await.map_err(|_| {
std::io::Error::new(std::io::ErrorKind::BrokenPipe, "failed to recv done msg")
})?;

tracing::debug!("party {} closing conn = {id}", self.my_id);

conn.close(
0u32.into(),
format!("close from party {}", self.my_id).as_bytes(),
);
}
}
for endpoint in self.endpoints {
for endpoint in self.endpoints.iter() {
endpoint.wait_idle().await;
endpoint.close(VarInt::from_u32(0), &[]);
}
Ok(())
}
}

0 comments on commit a9cbcbf

Please sign in to comment.