Skip to content

Commit

Permalink
multistream-select: Enforce io::error instead of empty protocols (#318
Browse files Browse the repository at this point in the history
)

This PR brings parity between the litep2p mutlistream-select
implementation and the libp2p one.

There was a mismatch in the litep2p implementation which resulted in
decoding empty bytes into `Message::Protocols([ ])`. In contrast, libp2p
returns an `io::error` since the message is invalid.

While at it have added a few tests to ensure our implementation works as
expected

cc @paritytech/networking

---------

Signed-off-by: Alexandru Vasile <[email protected]>
  • Loading branch information
lexnv authored Jan 29, 2025
1 parent 78d934f commit b7511c8
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 47 deletions.
33 changes: 17 additions & 16 deletions src/multistream_select/dialer_select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ use crate::{
error::{self, Error, ParseError},
multistream_select::{
protocol::{
encode_multistream_message, HeaderLine, Message, MessageIO, Protocol, ProtocolError,
webrtc_encode_multistream_message, HeaderLine, Message, MessageIO, Protocol,
ProtocolError,
},
Negotiated, NegotiationError, Version,
},
Expand Down Expand Up @@ -305,7 +306,7 @@ pub enum HandshakeResult {
/// Handshake state.
#[derive(Debug)]
enum HandshakeState {
/// Wainting to receive any response from remote peer.
/// Waiting to receive any response from remote peer.
WaitingResponse,

/// Waiting to receive the actual application protocol from remote peer.
Expand All @@ -314,7 +315,7 @@ enum HandshakeState {

/// `multistream-select` dialer handshake state.
#[derive(Debug)]
pub struct DialerState {
pub struct WebRtcDialerState {
/// Proposed main protocol.
protocol: ProtocolName,

Expand All @@ -325,16 +326,16 @@ pub struct DialerState {
state: HandshakeState,
}

impl DialerState {
impl WebRtcDialerState {
/// Propose protocol to remote peer.
///
/// Return [`DialerState`] which is used to drive forward the negotiation and an encoded
/// Return [`WebRtcDialerState`] which is used to drive forward the negotiation and an encoded
/// `multistream-select` message that contains the protocol proposal for the substream.
pub fn propose(
protocol: ProtocolName,
fallback_names: Vec<ProtocolName>,
) -> crate::Result<(Self, Vec<u8>)> {
let message = encode_multistream_message(
let message = webrtc_encode_multistream_message(
std::iter::once(protocol.clone())
.chain(fallback_names.clone())
.filter_map(|protocol| Protocol::try_from(protocol.as_ref()).ok())
Expand All @@ -353,7 +354,7 @@ impl DialerState {
))
}

/// Register response to [`DialerState`].
/// Register response to [`WebRtcDialerState`].
pub fn register_response(
&mut self,
payload: Vec<u8>,
Expand Down Expand Up @@ -755,7 +756,7 @@ mod tests {
#[test]
fn propose() {
let (mut dialer_state, message) =
DialerState::propose(ProtocolName::from("/13371338/proto/1"), vec![]).unwrap();
WebRtcDialerState::propose(ProtocolName::from("/13371338/proto/1"), vec![]).unwrap();
let message = bytes::BytesMut::from(&message[..]).freeze();

let Message::Protocols(protocols) = Message::decode(message).unwrap() else {
Expand All @@ -777,7 +778,7 @@ mod tests {

#[test]
fn propose_with_fallback() {
let (mut dialer_state, message) = DialerState::propose(
let (mut dialer_state, message) = WebRtcDialerState::propose(
ProtocolName::from("/13371338/proto/1"),
vec![ProtocolName::from("/sup/proto/1")],
)
Expand Down Expand Up @@ -813,7 +814,7 @@ mod tests {
let _ = message.encode(&mut bytes).map_err(|_| Error::InvalidData).unwrap();

let (mut dialer_state, _message) =
DialerState::propose(ProtocolName::from("/13371338/proto/1"), vec![]).unwrap();
WebRtcDialerState::propose(ProtocolName::from("/13371338/proto/1"), vec![]).unwrap();

match dialer_state.register_response(bytes.freeze().to_vec()) {
Err(error::NegotiationError::MultistreamSelectError(NegotiationError::Failed)) => {}
Expand All @@ -832,7 +833,7 @@ mod tests {
let _ = message.encode(&mut bytes).map_err(|_| Error::InvalidData).unwrap();

let (mut dialer_state, _message) =
DialerState::propose(ProtocolName::from("/13371338/proto/1"), vec![]).unwrap();
WebRtcDialerState::propose(ProtocolName::from("/13371338/proto/1"), vec![]).unwrap();

match dialer_state.register_response(bytes.freeze().to_vec()) {
Err(error::NegotiationError::MultistreamSelectError(NegotiationError::Failed)) => {}
Expand All @@ -842,7 +843,7 @@ mod tests {

#[test]
fn negotiate_main_protocol() {
let message = encode_multistream_message(
let message = webrtc_encode_multistream_message(
vec![Message::Protocol(
Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap(),
)]
Expand All @@ -851,7 +852,7 @@ mod tests {
.unwrap()
.freeze();

let (mut dialer_state, _message) = DialerState::propose(
let (mut dialer_state, _message) = WebRtcDialerState::propose(
ProtocolName::from("/13371338/proto/1"),
vec![ProtocolName::from("/sup/proto/1")],
)
Expand All @@ -860,13 +861,13 @@ mod tests {
match dialer_state.register_response(message.to_vec()) {
Ok(HandshakeResult::Succeeded(negotiated)) =>
assert_eq!(negotiated, ProtocolName::from("/13371338/proto/1")),
_ => panic!("invalid event"),
event => panic!("invalid event {event:?}"),
}
}

#[test]
fn negotiate_fallback_protocol() {
let message = encode_multistream_message(
let message = webrtc_encode_multistream_message(
vec![Message::Protocol(
Protocol::try_from(&b"/sup/proto/1"[..]).unwrap(),
)]
Expand All @@ -875,7 +876,7 @@ mod tests {
.unwrap()
.freeze();

let (mut dialer_state, _message) = DialerState::propose(
let (mut dialer_state, _message) = WebRtcDialerState::propose(
ProtocolName::from("/13371338/proto/1"),
vec![ProtocolName::from("/sup/proto/1")],
)
Expand Down
36 changes: 19 additions & 17 deletions src/multistream_select/listener_select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ use crate::{
error::{self, Error},
multistream_select::{
protocol::{
encode_multistream_message, HeaderLine, Message, MessageIO, Protocol, ProtocolError,
webrtc_encode_multistream_message, HeaderLine, Message, MessageIO, Protocol,
ProtocolError,
},
Negotiated, NegotiationError,
},
Expand Down Expand Up @@ -324,7 +325,7 @@ where
}
}

/// Result of [`listener_negotiate()`].
/// Result of [`webrtc_listener_negotiate()`].
#[derive(Debug)]
pub enum ListenerSelectResult {
/// Requested protocol is available and substream can be accepted.
Expand All @@ -348,7 +349,7 @@ pub enum ListenerSelectResult {
/// Parse protocols offered by the remote peer and check if any of the offered protocols match
/// locally available protocols. If a match is found, return an encoded multistream-select
/// response and the negotiated protocol. If parsing fails or no match is found, return an error.
pub fn listener_negotiate<'a>(
pub fn webrtc_listener_negotiate<'a>(
supported_protocols: &'a mut impl Iterator<Item = &'a ProtocolName>,
payload: Bytes,
) -> crate::Result<ListenerSelectResult> {
Expand Down Expand Up @@ -382,9 +383,9 @@ pub fn listener_negotiate<'a>(
if protocol.as_ref() == supported.as_bytes() {
return Ok(ListenerSelectResult::Accepted {
protocol: supported.clone(),
message: encode_multistream_message(std::iter::once(Message::Protocol(
protocol,
)))?,
message: webrtc_encode_multistream_message(std::iter::once(
Message::Protocol(protocol),
))?,
});
}
}
Expand All @@ -396,7 +397,7 @@ pub fn listener_negotiate<'a>(
);

Ok(ListenerSelectResult::Rejected {
message: encode_multistream_message(std::iter::once(Message::NotAvailable))?,
message: webrtc_encode_multistream_message(std::iter::once(Message::NotAvailable))?,
})
}

Expand All @@ -405,15 +406,15 @@ mod tests {
use super::*;

#[test]
fn listener_negotiate_works() {
fn webrtc_listener_negotiate_works() {
let mut local_protocols = vec![
ProtocolName::from("/13371338/proto/1"),
ProtocolName::from("/sup/proto/1"),
ProtocolName::from("/13371338/proto/2"),
ProtocolName::from("/13371338/proto/3"),
ProtocolName::from("/13371338/proto/4"),
];
let message = encode_multistream_message(
let message = webrtc_encode_multistream_message(
vec![
Message::Protocol(Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap()),
Message::Protocol(Protocol::try_from(&b"/sup/proto/1"[..]).unwrap()),
Expand All @@ -423,7 +424,7 @@ mod tests {
.unwrap()
.freeze();

match listener_negotiate(&mut local_protocols.iter(), message) {
match webrtc_listener_negotiate(&mut local_protocols.iter(), message) {
Err(error) => panic!("error received: {error:?}"),
Ok(ListenerSelectResult::Rejected { .. }) => panic!("message rejected"),
Ok(ListenerSelectResult::Accepted { protocol, message }) => {
Expand All @@ -441,14 +442,14 @@ mod tests {
ProtocolName::from("/13371338/proto/3"),
ProtocolName::from("/13371338/proto/4"),
];
let message = encode_multistream_message(std::iter::once(Message::Protocols(vec![
let message = webrtc_encode_multistream_message(std::iter::once(Message::Protocols(vec![
Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap(),
Protocol::try_from(&b"/sup/proto/1"[..]).unwrap(),
])))
.unwrap()
.freeze();

match listener_negotiate(&mut local_protocols.iter(), message) {
match webrtc_listener_negotiate(&mut local_protocols.iter(), message) {
Err(error) => assert!(std::matches!(error, Error::InvalidData)),
_ => panic!("invalid event"),
}
Expand All @@ -469,7 +470,7 @@ mod tests {
let message = Message::Header(HeaderLine::V1);
let _ = message.encode(&mut bytes).map_err(|_| Error::InvalidData).unwrap();

match listener_negotiate(&mut local_protocols.iter(), bytes.freeze()) {
match webrtc_listener_negotiate(&mut local_protocols.iter(), bytes.freeze()) {
Err(error) => assert!(std::matches!(
error,
Error::NegotiationError(error::NegotiationError::MultistreamSelectError(
Expand Down Expand Up @@ -498,7 +499,7 @@ mod tests {
]);
let _ = message.encode(&mut bytes).map_err(|_| Error::InvalidData).unwrap();

match listener_negotiate(&mut local_protocols.iter(), bytes.freeze()) {
match webrtc_listener_negotiate(&mut local_protocols.iter(), bytes.freeze()) {
Err(error) => assert!(std::matches!(
error,
Error::NegotiationError(error::NegotiationError::MultistreamSelectError(
Expand All @@ -518,7 +519,7 @@ mod tests {
ProtocolName::from("/13371338/proto/3"),
ProtocolName::from("/13371338/proto/4"),
];
let message = encode_multistream_message(
let message = webrtc_encode_multistream_message(
vec![Message::Protocol(
Protocol::try_from(&b"/13371339/proto/1"[..]).unwrap(),
)]
Expand All @@ -527,12 +528,13 @@ mod tests {
.unwrap()
.freeze();

match listener_negotiate(&mut local_protocols.iter(), message) {
match webrtc_listener_negotiate(&mut local_protocols.iter(), message) {
Err(error) => panic!("error received: {error:?}"),
Ok(ListenerSelectResult::Rejected { message }) => {
assert_eq!(
message,
encode_multistream_message(std::iter::once(Message::NotAvailable)).unwrap()
webrtc_encode_multistream_message(std::iter::once(Message::NotAvailable))
.unwrap()
);
}
Ok(ListenerSelectResult::Accepted { protocol, message }) => panic!("message accepted"),
Expand Down
5 changes: 3 additions & 2 deletions src/multistream_select/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,10 @@ mod negotiated;
mod protocol;

pub use crate::multistream_select::{
dialer_select::{dialer_select_proto, DialerSelectFuture, DialerState, HandshakeResult},
dialer_select::{dialer_select_proto, DialerSelectFuture, HandshakeResult, WebRtcDialerState},
listener_select::{
listener_negotiate, listener_select_proto, ListenerSelectFuture, ListenerSelectResult,
listener_select_proto, webrtc_listener_negotiate, ListenerSelectFuture,
ListenerSelectResult,
},
negotiated::{Negotiated, NegotiatedComplete, NegotiationError},
protocol::{HeaderLine, Message, Protocol, ProtocolError},
Expand Down
81 changes: 78 additions & 3 deletions src/multistream_select/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,7 @@ impl Message {
let mut remaining: &[u8] = &msg;
loop {
// A well-formed message must be terminated with a newline.
// TODO: don't do this
if remaining == [b'\n'] || remaining.is_empty() {
if remaining == [b'\n'] {
break;
} else if protocols.len() == MAX_PROTOCOLS {
return Err(ProtocolError::TooManyProtocols);
Expand All @@ -228,7 +227,12 @@ impl Message {
}

/// Create `multistream-select` message from an iterator of `Message`s.
pub fn encode_multistream_message(
///
/// # Note
///
/// This is implementation is not compliant with the multistream-select protocol spec.
/// The only purpose of this was to get the `multistream-select` protocol working with smoldot.
pub fn webrtc_encode_multistream_message(
messages: impl IntoIterator<Item = Message>,
) -> crate::Result<BytesMut> {
// encode `/multistream-select/1.0.0` header
Expand All @@ -245,6 +249,9 @@ pub fn encode_multistream_message(
header.append(&mut proto_bytes);
}

// For the `Message::Protocols` to be interpreted correctly, it must be followed by a newline.
header.push(b'\n');

Ok(BytesMut::from(&header[..]))
}

Expand Down Expand Up @@ -468,3 +475,71 @@ impl From<uvi::decode::Error> for ProtocolError {
Self::from(io::Error::new(io::ErrorKind::InvalidData, err.to_string()))
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_decode_main_messages() {
// Decode main messages.
let bytes = Bytes::from_static(MSG_MULTISTREAM_1_0);
assert_eq!(
Message::decode(bytes).unwrap(),
Message::Header(HeaderLine::V1)
);

let bytes = Bytes::from_static(MSG_PROTOCOL_NA);
assert_eq!(Message::decode(bytes).unwrap(), Message::NotAvailable);

let bytes = Bytes::from_static(MSG_LS);
assert_eq!(Message::decode(bytes).unwrap(), Message::ListProtocols);
}

#[test]
fn test_decode_empty_message() {
// Empty message should decode to an IoError, not Header::Protocols.
let bytes = Bytes::from_static(b"");
match Message::decode(bytes).unwrap_err() {
ProtocolError::IoError(io) => assert_eq!(io.kind(), io::ErrorKind::InvalidData),
err => panic!("Unexpected error: {:?}", err),
};
}

#[test]
fn test_decode_protocols() {
// Single protocol.
let bytes = Bytes::from_static(b"/protocol-v1\n");
assert_eq!(
Message::decode(bytes).unwrap(),
Message::Protocol(Protocol::try_from(Bytes::from_static(b"/protocol-v1")).unwrap())
);

// Multiple protocols.
let expected = Message::Protocols(vec![
Protocol::try_from(Bytes::from_static(b"/protocol-v1")).unwrap(),
Protocol::try_from(Bytes::from_static(b"/protocol-v2")).unwrap(),
]);
let mut encoded = BytesMut::new();
expected.encode(&mut encoded).unwrap();

// `\r` is the length of the protocol names.
let bytes = Bytes::from_static(b"\r/protocol-v1\n\r/protocol-v2\n\n");
assert_eq!(encoded, bytes);

assert_eq!(
Message::decode(bytes).unwrap(),
Message::Protocols(vec![
Protocol::try_from(Bytes::from_static(b"/protocol-v1")).unwrap(),
Protocol::try_from(Bytes::from_static(b"/protocol-v2")).unwrap(),
])
);

// Check invalid length.
let bytes = Bytes::from_static(b"\r/v1\n\n");
assert_eq!(
Message::decode(bytes).unwrap_err(),
ProtocolError::InvalidMessage
);
}
}
Loading

0 comments on commit b7511c8

Please sign in to comment.