Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core/transport/memory: Return dialer address in Upgrade event #1724

Merged
merged 7 commits into from
Sep 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions core/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
two peer IDs are equal if and only if they use the same hash algorithm and
have the same hash digest. [PR 1608](https://github.com/libp2p/rust-libp2p/pull/1608).

- Return dialer address instead of listener address as `remote_addr` in
`MemoryTransport` `Listener` `ListenerEvent::Upgrade`
[PR 1724](https://github.com/libp2p/rust-libp2p/pull/1724).

# 0.21.0 [2020-08-18]

- Remove duplicates when performing address translation
Expand Down
2 changes: 1 addition & 1 deletion core/src/connection/listeners.rs
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ mod tests {
match listeners.next().await.unwrap() {
ListenersEvent::Incoming { local_addr, send_back_addr, .. } => {
assert_eq!(local_addr, address);
assert_eq!(send_back_addr, address);
assert!(send_back_addr != address);
},
_ => panic!()
}
Expand Down
242 changes: 199 additions & 43 deletions core/src/transport/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,57 @@ use rw_stream_sink::RwStreamSink;
use std::{collections::hash_map::Entry, error, fmt, io, num::NonZeroU64, pin::Pin};

lazy_static! {
static ref HUB: Mutex<FnvHashMap<NonZeroU64, mpsc::Sender<Channel<Vec<u8>>>>> =
Mutex::new(FnvHashMap::default());
static ref HUB: Hub = Hub(Mutex::new(FnvHashMap::default()));
}

struct Hub(Mutex<FnvHashMap<NonZeroU64, ChannelSender>>);

/// A [`mpsc::Sender`] enabling a [`DialFuture`] to send a [`Channel`] and the
/// port of the dialer to a [`Listener`].
type ChannelSender = mpsc::Sender<(Channel<Vec<u8>>, NonZeroU64)>;

/// A [`mpsc::Receiver`] enabling a [`Listener`] to receive a [`Channel`] and
/// the port of the dialer from a [`DialFuture`].
type ChannelReceiver = mpsc::Receiver<(Channel<Vec<u8>>, NonZeroU64)>;

impl Hub {
/// Registers the given port on the hub.
///
/// Randomizes port when given port is `0`. Returns [`None`] when given port
/// is already occupied.
fn register_port(&self, port: u64) -> Option<(ChannelReceiver, NonZeroU64)> {
let mut hub = self.0.lock();

let port = if let Some(port) = NonZeroU64::new(port) {
port
} else {
loop {
let port = match NonZeroU64::new(rand::random()) {
Some(p) => p,
None => continue,
};
if !hub.contains_key(&port) {
break port;
}
}
};

let (tx, rx) = mpsc::channel(2);
match hub.entry(port) {
Entry::Occupied(_) => return None,
Entry::Vacant(e) => e.insert(tx)
};

Some((rx, port))
}

fn unregister_port(&self, port: &NonZeroU64) -> Option<ChannelSender> {
self.0.lock().remove(port)
}

fn get(&self, port: &NonZeroU64) -> Option<ChannelSender> {
self.0.lock().get(port).cloned()
}
}

/// Transport that supports `/memory/N` multiaddresses.
Expand All @@ -38,15 +87,49 @@ pub struct MemoryTransport;

/// Connection to a `MemoryTransport` currently being opened.
pub struct DialFuture {
sender: mpsc::Sender<Channel<Vec<u8>>>,
/// Ephemeral source port.
///
/// These ports mimic TCP ephemeral source ports but are not actually used
/// by the memory transport due to the direct use of channels. They merely
/// ensure that every connection has a unique address for each dialer, which
/// is not at the same time a listen address (analogous to TCP).
dial_port: NonZeroU64,
sender: ChannelSender,
channel_to_send: Option<Channel<Vec<u8>>>,
channel_to_return: Option<Channel<Vec<u8>>>,
}

impl DialFuture {
fn new(port: NonZeroU64) -> Option<Self> {
let sender = HUB.get(&port)?.clone();

let (_dial_port_channel, dial_port) = HUB.register_port(0)
.expect("there to be some random unoccupied port.");

let (a_tx, a_rx) = mpsc::channel(4096);
let (b_tx, b_rx) = mpsc::channel(4096);
Some(DialFuture {
dial_port,
sender,
channel_to_send: Some(RwStreamSink::new(Chan {
incoming: a_rx,
outgoing: b_tx,
dial_port: None,
})),
channel_to_return: Some(RwStreamSink::new(Chan {
incoming: b_rx,
outgoing: a_tx,
dial_port: Some(dial_port),
})),
})
}
}

impl Future for DialFuture {
type Output = Result<Channel<Vec<u8>>, MemoryTransportError>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {

match self.sender.poll_ready(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(Ok(())) => {},
Expand All @@ -55,7 +138,8 @@ impl Future for DialFuture {

let channel_to_send = self.channel_to_send.take()
.expect("Future should not be polled again once complete");
match self.sender.start_send(channel_to_send) {
let dial_port = self.dial_port;
match self.sender.start_send((channel_to_send, dial_port)) {
Err(_) => return Poll::Ready(Err(MemoryTransportError::Unreachable)),
Ok(()) => {}
}
Expand All @@ -79,28 +163,9 @@ impl Transport for MemoryTransport {
return Err(TransportError::MultiaddrNotSupported(addr));
};

let mut hub = (&*HUB).lock();

let port = if let Some(port) = NonZeroU64::new(port) {
port
} else {
loop {
let port = match NonZeroU64::new(rand::random()) {
Some(p) => p,
None => continue,
};
if !hub.contains_key(&port) {
break port;
}
}
};


let (tx, rx) = mpsc::channel(2);
match hub.entry(port) {
Entry::Occupied(_) =>
return Err(TransportError::Other(MemoryTransportError::Unreachable)),
Entry::Vacant(e) => e.insert(tx)
let (rx, port) = match HUB.register_port(port) {
Some((rx, port)) => (rx, port),
None => return Err(TransportError::Other(MemoryTransportError::Unreachable)),
};

let listener = Listener {
Expand All @@ -124,19 +189,7 @@ impl Transport for MemoryTransport {
return Err(TransportError::MultiaddrNotSupported(addr));
};

let hub = HUB.lock();
if let Some(sender) = hub.get(&port) {
let (a_tx, a_rx) = mpsc::channel(4096);
let (b_tx, b_rx) = mpsc::channel(4096);
Ok(DialFuture {
sender: sender.clone(),
channel_to_send: Some(RwStreamSink::new(Chan { incoming: a_rx, outgoing: b_tx })),
channel_to_return: Some(RwStreamSink::new(Chan { incoming: b_rx, outgoing: a_tx })),

})
} else {
Err(TransportError::Other(MemoryTransportError::Unreachable))
}
DialFuture::new(port).ok_or(TransportError::Other(MemoryTransportError::Unreachable))
}
}

Expand Down Expand Up @@ -167,7 +220,7 @@ pub struct Listener {
/// The address we are listening on.
addr: Multiaddr,
/// Receives incoming connections.
receiver: mpsc::Receiver<Channel<Vec<u8>>>,
receiver: ChannelReceiver,
/// Generate `ListenerEvent::NewAddress` to inform about our listen address.
tell_listen_addr: bool
}
Expand All @@ -181,7 +234,7 @@ impl Stream for Listener {
return Poll::Ready(Some(Ok(ListenerEvent::NewAddress(self.addr.clone()))))
}

let channel = match Stream::poll_next(Pin::new(&mut self.receiver), cx) {
let (channel, dial_port) = match Stream::poll_next(Pin::new(&mut self.receiver), cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(None) => panic!("Alive listeners always have a sender."),
Poll::Ready(Some(v)) => v,
Expand All @@ -190,7 +243,7 @@ impl Stream for Listener {
let event = ListenerEvent::Upgrade {
upgrade: future::ready(Ok(channel)),
local_addr: self.addr.clone(),
remote_addr: Protocol::Memory(self.port.get()).into()
remote_addr: Protocol::Memory(dial_port.get()).into()
};

Poll::Ready(Some(Ok(event)))
Expand All @@ -199,7 +252,7 @@ impl Stream for Listener {

impl Drop for Listener {
fn drop(&mut self) {
let val_in = HUB.lock().remove(&self.port);
let val_in = HUB.unregister_port(&self.port);
debug_assert!(val_in.is_some());
}
}
Expand Down Expand Up @@ -232,6 +285,14 @@ pub type Channel<T> = RwStreamSink<Chan<T>>;
pub struct Chan<T = Vec<u8>> {
incoming: mpsc::Receiver<T>,
outgoing: mpsc::Sender<T>,

// Needed in [`Drop`] implementation of [`Chan`] to unregister the dialing
// port with the global [`HUB`]. Is [`Some`] when [`Chan`] of dialer and
// [`None`] when [`Chan`] of listener.
//
// Note: Listening port is unregistered in [`Drop`] implementation of
// [`Listener`].
dial_port: Option<NonZeroU64>,
}

impl<T> Unpin for Chan<T> {
Expand Down Expand Up @@ -276,6 +337,15 @@ impl<T: AsRef<[u8]>> Into<RwStreamSink<Chan<T>>> for Chan<T> {
}
}

impl<T> Drop for Chan<T> {
fn drop(&mut self) {
if let Some(port) = self.dial_port {
let channel_sender = HUB.unregister_port(&port);
debug_assert!(channel_sender.is_some());
}
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -350,4 +420,90 @@ mod tests {

futures::executor::block_on(futures::future::join(listener, dialer));
}

#[test]
fn dialer_address_unequal_to_listener_address() {
let listener_addr: Multiaddr = Protocol::Memory(
rand::random::<u64>().saturating_add(1),
).into();
let listener_addr_cloned = listener_addr.clone();

let listener_transport = MemoryTransport::default();

let listener = async move {
let mut listener = listener_transport.listen_on(listener_addr.clone())
.unwrap();
while let Some(ev) = listener.next().await {
if let ListenerEvent::Upgrade { remote_addr, .. } = ev.unwrap() {
assert!(
remote_addr != listener_addr,
"Expect dialer address not to equal listener address."
);
return;
}
}
};

let dialer = async move {
MemoryTransport::default().dial(listener_addr_cloned)
.unwrap()
.await
.unwrap();
};

futures::executor::block_on(futures::future::join(listener, dialer));
}

#[test]
fn dialer_port_is_deregistered() {
let (terminate, should_terminate) = futures::channel::oneshot::channel();
let (terminated, is_terminated) = futures::channel::oneshot::channel();

let listener_addr: Multiaddr = Protocol::Memory(
rand::random::<u64>().saturating_add(1),
).into();
let listener_addr_cloned = listener_addr.clone();

let listener_transport = MemoryTransport::default();

let listener = async move {
let mut listener = listener_transport.listen_on(listener_addr.clone())
.unwrap();
while let Some(ev) = listener.next().await {
if let ListenerEvent::Upgrade { remote_addr, .. } = ev.unwrap() {
let dialer_port = NonZeroU64::new(
parse_memory_addr(&remote_addr).unwrap(),
).unwrap();

assert!(
HUB.get(&dialer_port).is_some(),
"Expect dialer port to stay registered while connection is in use.",
);

terminate.send(()).unwrap();
is_terminated.await.unwrap();

assert!(
HUB.get(&dialer_port).is_none(),
"Expect dialer port to be deregistered once connection is dropped.",
);

return;
}
}
};

let dialer = async move {
let _chan = MemoryTransport::default().dial(listener_addr_cloned)
.unwrap()
.await
.unwrap();

should_terminate.await.unwrap();
drop(_chan);
terminated.send(()).unwrap();
};

futures::executor::block_on(futures::future::join(listener, dialer));
}
}