Skip to content

Commit

Permalink
Generalize Dispatch over address type
Browse files Browse the repository at this point in the history
  • Loading branch information
sgdxbc committed Apr 13, 2024
1 parent 2a17fa9 commit 36728d6
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 73 deletions.
4 changes: 2 additions & 2 deletions crates/entropy/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ async fn start_peer(
// it done)
kademlia::Peer::<_, _, _, BlackHole, _>::new(
buckets,
Box::new(MessageNet::new(dispatch::Net(tcp_control_session.sender())))
Box::new(MessageNet::new(dispatch::Net::from(tcp_control_session.sender())))
as Box<dyn kademlia::Net<SocketAddr> + Send + Sync>,
// MessageNet::new(DispatchNet(Sender::from(quic_control_session.sender()))),
Sender::from(kademlia_control_session.sender()),
Expand All @@ -308,7 +308,7 @@ async fn start_peer(
),
));
let mut kademlia_control = Blanket(Buffered::from(Control::new(
Box::new(dispatch::Net(tcp_control_session.sender()))
Box::new(dispatch::Net::from(tcp_control_session.sender()))
as Box<dyn augustus::net::kademlia::Net<SocketAddr, bytes::Bytes> + Send + Sync>,
// DispatchNet(Sender::from(quic_control_session.sender())),
Box::new(Sender::from(kademlia_session.sender()))
Expand Down
12 changes: 6 additions & 6 deletions examples/bench-unreplicated.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ async fn main() -> anyhow::Result<()> {
erased::session::Sender::from(close_loop_session.sender()),
));
let mut state_sender = state_session.sender();
let mut tcp_control = Dispatch::<_, _, bytes::Bytes, _>::new(
let mut tcp_control = Dispatch::<_, _, _, bytes::Bytes, _>::new(
Tcp::new(None)?,
move |buf: &_| to_client_on_buf(buf, &mut state_sender),
// effectively disable connection table clean up
Expand All @@ -141,7 +141,7 @@ async fn main() -> anyhow::Result<()> {
);
} else {
let mut tcp_session = Session::new();
let raw_net = Net(tcp_session.sender());
let raw_net = Net::from(tcp_session.sender());
let mut state = Unify(Client::new(
id,
listener.local_addr()?,
Expand Down Expand Up @@ -173,7 +173,7 @@ async fn main() -> anyhow::Result<()> {
}
} else if flag_quic {
let mut quic_session = Session::new();
let raw_net = Net(quic_session.sender());
let raw_net = Net::from(quic_session.sender());
let quic = Quic::new(client_addr)?;
let mut state = Unify(Client::new(
id,
Expand Down Expand Up @@ -317,7 +317,7 @@ async fn main() -> anyhow::Result<()> {
let mut state = Unify(Replica::new(Null, ToClientMessageNet::new(simplex::Tcp)));
let mut state_session = Session::new();
let mut state_sender = state_session.sender();
let mut tcp_control = Dispatch::<_, _, bytes::Bytes, _>::new(
let mut tcp_control = Dispatch::<_, _, _, bytes::Bytes, _>::new(
Tcp::new(None)?,
move |buf: &_| to_replica_on_buf::<SocketAddr>(buf, &mut state_sender),
BlackHole,
Expand All @@ -330,7 +330,7 @@ async fn main() -> anyhow::Result<()> {
}

let mut tcp_session = Session::new();
let raw_net = Net(tcp_session.sender());
let raw_net = Net::from(tcp_session.sender());
let mut state = Unify(Replica::new(Null, ToClientMessageNet::new(raw_net)));
let mut state_session = Session::new();
let mut state_sender = state_session.sender();
Expand Down Expand Up @@ -359,7 +359,7 @@ async fn main() -> anyhow::Result<()> {

if flag_quic {
let mut quic_session = Session::new();
let raw_net = Net(quic_session.sender());
let raw_net = Net::from(quic_session.sender());
let mut state = Unify(Replica::new(Null, ToClientMessageNet::new(raw_net)));
let mut state_session = Session::new();
let mut state_sender = state_session.sender();
Expand Down
2 changes: 1 addition & 1 deletion examples/stress-connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ async fn main() -> anyhow::Result<()> {
quic,
control_session.sender(),
));
let mut net = Net(control_session.sender());
let mut net = Net::from(control_session.sender());
sessions.spawn(async move { control_session.run(&mut control).await });
sessions.spawn(async move {
for j in 0..multiplier {
Expand Down
114 changes: 62 additions & 52 deletions src/net/dispatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,36 +27,33 @@
// `Dispatch`. consider the fact that kernel already always maintains a
// connection table (and yet another queuing layer), i generally don't satisfy
// with this solution
use std::{
collections::{hash_map::Entry, HashMap},
net::SocketAddr,
};
use std::collections::{hash_map::Entry, HashMap};

use derive_where::derive_where;

use tracing::{debug, warn};

use crate::event::{erased, OnEvent, OnTimer, SendEvent, SendEventOnce, Timer};

use super::{Buf, IterAddr, SendMessage};
use super::{Addr, Buf, IterAddr, SendMessage};

#[derive_where(Debug; E, P, P::Sender)]
pub struct Dispatch<E, P: Protocol<B>, B, F> {
#[derive_where(Debug; E, P, P::Sender, A)]
pub struct Dispatch<E, P: Protocol<A, B>, A, B, F> {
protocol: P,
connections: HashMap<SocketAddr, Connection<P, B>>,
connections: HashMap<A, Connection<P::Sender>>,
seq: u32,
close_sender: E,
#[derive_where(skip)]
on_buf: F,
}

#[derive_where(Debug; P::Sender)]
struct Connection<P: Protocol<B>, B> {
sender: P::Sender,
#[derive(Debug)]
struct Connection<E> {
sender: E,
seq: u32,
}

impl<E, P: Protocol<B>, B, F> Dispatch<E, P, B, F> {
impl<E, P: Protocol<A, B>, A, B, F> Dispatch<E, P, A, B, F> {
pub fn new(protocol: P, on_buf: F, close_sender: E) -> anyhow::Result<Self> {
Ok(Self {
protocol,
Expand All @@ -68,58 +65,68 @@ impl<E, P: Protocol<B>, B, F> Dispatch<E, P, B, F> {
}
}

// go with typed event for this state machine because it takes a sender that
// sends to itself and must be `Clone` at the same time
// i.e. "the horror" of type erasure
#[derive(derive_more::From)]
pub enum Event<P: Protocol<B>, B> {
pub enum Event<P: Protocol<A, B>, A, B> {
Incoming(Incoming<P::Incoming>),
Outgoing(Outgoing<B>),
Closed(Closed),
Outgoing(Outgoing<A, B>),
Closed(Closed<A>),
}

pub struct Closed(SocketAddr, u32);
pub struct Closed<A>(A, u32);

pub struct CloseGuard<E>(E, Option<SocketAddr>, u32);
pub struct CloseGuard<E, A>(E, Option<A>, u32);

impl<E: SendEventOnce<Closed>> CloseGuard<E> {
pub fn close(self, addr: SocketAddr) -> anyhow::Result<()> {
impl<E: SendEventOnce<Closed<A>>, A: Addr> CloseGuard<E, A> {
pub fn close(self, addr: A) -> anyhow::Result<()> {
if let Some(also_addr) = self.1 {
anyhow::ensure!(addr == also_addr)
}
self.0.send_once(Closed(addr, self.2))
}
}

pub trait Protocol<B> {
pub trait Protocol<A, B> {
type Sender: SendEvent<B>;

fn connect<E: SendEventOnce<Closed> + Send + 'static>(
fn connect<E: SendEventOnce<Closed<A>> + Send + 'static>(
&self,
remote: SocketAddr,
remote: A,
on_buf: impl FnMut(&[u8]) -> anyhow::Result<()> + Clone + Send + 'static,
close_guard: CloseGuard<E>,
close_guard: CloseGuard<E, A>,
) -> Self::Sender;

type Incoming;

fn accept<E: SendEventOnce<Closed> + Send + 'static>(
fn accept<E: SendEventOnce<Closed<A>> + Send + 'static>(
connection: Self::Incoming,
on_buf: impl FnMut(&[u8]) -> anyhow::Result<()> + Clone + Send + 'static,
close_guard: CloseGuard<E>,
) -> Option<(SocketAddr, Self::Sender)>;
close_guard: CloseGuard<E, A>,
) -> Option<(A, Self::Sender)>;
}

pub struct Outgoing<B>(SocketAddr, B);
pub struct Outgoing<A, B>(A, B);

#[derive(Clone)]
pub struct Net<E>(pub E);
pub struct Net<E, A>(pub E, std::marker::PhantomData<A>);
// mark address type so the following implementations not conflict

impl<E: SendEvent<Outgoing<B>>, B> SendMessage<SocketAddr, B> for Net<E> {
fn send(&mut self, dest: SocketAddr, message: B) -> anyhow::Result<()> {
impl<E, A> From<E> for Net<E, A> {
fn from(value: E) -> Self {
Self(value, Default::default())
}
}

impl<E: SendEvent<Outgoing<A, B>>, A, B> SendMessage<A, B> for Net<E, A> {
fn send(&mut self, dest: A, message: B) -> anyhow::Result<()> {
self.0.send(Outgoing(dest, message))
}
}

impl<E: SendEvent<Outgoing<B>>, B: Buf> SendMessage<IterAddr<'_, SocketAddr>, B> for Net<E> {
fn send(&mut self, dest: IterAddr<'_, SocketAddr>, message: B) -> anyhow::Result<()> {
impl<E: SendEvent<Outgoing<A, B>>, A, B: Buf> SendMessage<IterAddr<'_, A>, B> for Net<E, A> {
fn send(&mut self, dest: IterAddr<'_, A>, message: B) -> anyhow::Result<()> {
for addr in dest.0 {
SendMessage::send(self, addr, message.clone())?
}
Expand All @@ -128,13 +135,14 @@ impl<E: SendEvent<Outgoing<B>>, B: Buf> SendMessage<IterAddr<'_, SocketAddr>, B>
}

impl<
E: SendEventOnce<Closed> + Clone + Send + 'static,
P: Protocol<B>,
E: SendEventOnce<Closed<A>> + Clone + Send + 'static,
P: Protocol<A, B>,
A: Addr,
B: Buf,
F: FnMut(&[u8]) -> anyhow::Result<()> + Clone + Send + 'static,
> OnEvent for Dispatch<E, P, B, F>
> OnEvent for Dispatch<E, P, A, B, F>
{
type Event = Event<P, B>;
type Event = Event<P, A, B>;

fn on_event(&mut self, event: Self::Event, timer: &mut impl Timer) -> anyhow::Result<()> {
match event {
Expand All @@ -146,22 +154,23 @@ impl<
}

impl<
E: SendEventOnce<Closed> + Clone + Send + 'static,
P: Protocol<B>,
E: SendEventOnce<Closed<A>> + Clone + Send + 'static,
P: Protocol<A, B>,
A: Addr,
B: Buf,
F: FnMut(&[u8]) -> anyhow::Result<()> + Clone + Send + 'static,
> erased::OnEvent<Outgoing<B>> for Dispatch<E, P, B, F>
> erased::OnEvent<Outgoing<A, B>> for Dispatch<E, P, A, B, F>
{
fn on_event(
&mut self,
Outgoing(remote, buf): Outgoing<B>,
Outgoing(remote, buf): Outgoing<A, B>,
_: &mut impl Timer,
) -> anyhow::Result<()> {
if let Some(connection) = self.connections.get_mut(&remote) {
match connection.sender.send(buf.clone()) {
Ok(()) => return Ok(()),
Err(err) => {
warn!(">=> {remote} connection discontinued: {err}");
warn!(">=> {remote:?} connection discontinued: {err}");
self.connections.remove(&remote);
// in an ideal world the SendError will return the buf back to us, and we can
// directly reuse that in below, saving a `clone` above especially for fast path
Expand All @@ -173,12 +182,12 @@ impl<
}
}
self.seq += 1;
let close_guard = CloseGuard(self.close_sender.clone(), Some(remote), self.seq);
let close_guard = CloseGuard(self.close_sender.clone(), Some(remote.clone()), self.seq);
let mut sender = self
.protocol
.connect(remote, self.on_buf.clone(), close_guard);
.connect(remote.clone(), self.on_buf.clone(), close_guard);
if sender.send(buf).is_err() {
warn!(">=> {remote} new connection immediately fail")
warn!(">=> {remote:?} new connection immediately fail")
// we don't try again in such case since the remote is probably never reachable anymore
// not sure whether this should be considered as a fatal error. if this is happening,
// will it happen for every following outgoing connection?
Expand All @@ -198,11 +207,12 @@ impl<
pub struct Incoming<T>(pub T);

impl<
E: SendEventOnce<Closed> + Clone + Send + 'static,
P: Protocol<B>,
E: SendEventOnce<Closed<A>> + Clone + Send + 'static,
P: Protocol<A, B>,
A: Addr,
B: Buf,
F: FnMut(&[u8]) -> anyhow::Result<()> + Clone + Send + 'static,
> erased::OnEvent<Incoming<P::Incoming>> for Dispatch<E, P, B, F>
> erased::OnEvent<Incoming<P::Incoming>> for Dispatch<E, P, A, B, F>
{
fn on_event(
&mut self,
Expand All @@ -217,20 +227,20 @@ impl<
// always prefer to keep the connection created locally
// the connection in `self.connections` may not be created locally, but the incoming
// connection is definitely created remotely
if let Entry::Vacant(entry) = self.connections.entry(remote) {
if let Entry::Vacant(entry) = self.connections.entry(remote.clone()) {
entry.insert(Connection {
sender,
seq: self.seq,
});
} else {
warn!("<<< {remote} incoming connection from connected address")
warn!("<<< {remote:?} incoming connection from connected address")
}
Ok(())
}
}

impl<E, P: Protocol<B>, B, F> erased::OnEvent<Closed> for Dispatch<E, P, B, F> {
fn on_event(&mut self, Closed(addr, seq): Closed, _: &mut impl Timer) -> anyhow::Result<()> {
impl<E, P: Protocol<A, B>, A: Addr, B, F> erased::OnEvent<Closed<A>> for Dispatch<E, P, A, B, F> {
fn on_event(&mut self, Closed(addr, seq): Closed<A>, _: &mut impl Timer) -> anyhow::Result<()> {
if let Some(connection) = self.connections.get(&addr) {
if connection.seq == seq {
debug!(">>> {addr:?} outgoing connection closed");
Expand All @@ -241,7 +251,7 @@ impl<E, P: Protocol<B>, B, F> erased::OnEvent<Closed> for Dispatch<E, P, B, F> {
}
}

impl<E, P: Protocol<B>, B, F> OnTimer for Dispatch<E, P, B, F> {
impl<E, P: Protocol<A, B>, A, B, F> OnTimer for Dispatch<E, P, A, B, F> {
fn on_timer(&mut self, _: crate::event::TimerId, _: &mut impl Timer) -> anyhow::Result<()> {
unreachable!()
}
Expand Down
14 changes: 7 additions & 7 deletions src/net/session/quic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,10 @@ impl Quic {
}
}

async fn write_task<B: Buf, E: SendEventOnce<Closed>>(
async fn write_task<B: Buf, E: SendEventOnce<Closed<SocketAddr>>>(
connection: quinn::Connection,
mut receiver: UnboundedReceiver<B>,
close_guard: CloseGuard<E>,
close_guard: CloseGuard<E, SocketAddr>,
) {
loop {
enum Select<B> {
Expand Down Expand Up @@ -128,14 +128,14 @@ impl Quic {
}
}

impl<B: Buf> Protocol<B> for Quic {
impl<B: Buf> Protocol<SocketAddr, B> for Quic {
type Sender = UnboundedSender<B>;

fn connect<E: SendEventOnce<Closed> + Send + 'static>(
fn connect<E: SendEventOnce<Closed<SocketAddr>> + Send + 'static>(
&self,
remote: SocketAddr,
on_buf: impl FnMut(&[u8]) -> anyhow::Result<()> + Clone + Send + 'static,
close_guard: CloseGuard<E>,
close_guard: CloseGuard<E, SocketAddr>,
) -> Self::Sender {
let endpoint = self.0.clone();
// tracing::debug!("{:?} connect {remote}", endpoint.local_addr());
Expand Down Expand Up @@ -169,10 +169,10 @@ impl<B: Buf> Protocol<B> for Quic {

type Incoming = quinn::Connection;

fn accept<E: SendEventOnce<Closed> + Send + 'static>(
fn accept<E: SendEventOnce<Closed<SocketAddr>> + Send + 'static>(
connection: Self::Incoming,
on_buf: impl FnMut(&[u8]) -> anyhow::Result<()> + Clone + Send + 'static,
close_guard: CloseGuard<E>,
close_guard: CloseGuard<E, SocketAddr>,
) -> Option<(SocketAddr, Self::Sender)> {
let remote = connection.remote_address();
tokio::spawn(Self::read_task(connection.clone(), on_buf));
Expand Down
Loading

0 comments on commit 36728d6

Please sign in to comment.