diff --git a/examples/udp_vpv4_multicast/Cargo.toml b/examples/udp_vpv4_multicast/Cargo.toml new file mode 100644 index 0000000..ac6f0e9 --- /dev/null +++ b/examples/udp_vpv4_multicast/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "udp_ipv4_multicast" +version = "0.1.0" +edition = "2024" +publish = false + +[dependencies] +tokio = "1" +turmoil = { path = "../.." } +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/udp_vpv4_multicast/src/main.rs b/examples/udp_vpv4_multicast/src/main.rs new file mode 100644 index 0000000..3493f3d --- /dev/null +++ b/examples/udp_vpv4_multicast/src/main.rs @@ -0,0 +1,56 @@ +use std::{net::Ipv4Addr, time::Duration}; +use tracing::info; +use turmoil::{IpVersion, net::UdpSocket}; + +const N_STEPS: usize = 3; + +fn main() { + tracing_subscriber::fmt() + .with_env_filter( + tracing_subscriber::EnvFilter::builder() + .with_default_directive(tracing::level_filters::LevelFilter::INFO.into()) + .from_env_lossy(), + ) + .init(); + + let tick = Duration::from_millis(100); + let mut sim = turmoil::Builder::new() + .tick_duration(tick) + .ip_version(IpVersion::V4) + .build(); + + let multicast_port = 9000; + let multicast_addr = "239.0.0.1".parse().unwrap(); + + for server_index in 0..2 { + sim.client(format!("server-{server_index}"), async move { + let socket = UdpSocket::bind((Ipv4Addr::UNSPECIFIED, multicast_port)).await?; + socket.join_multicast_v4(multicast_addr, Ipv4Addr::UNSPECIFIED)?; + + let mut buf = [0; 1024]; + for _ in 0..N_STEPS { + let (n, addr) = socket.recv_from(&mut buf).await?; + let data = &buf[0..n]; + + info!("UDP packet from {} has been received: {:?}", addr, data); + } + Ok(()) + }); + } + + sim.client("client", async move { + let dst = (multicast_addr, multicast_port); + let socket = UdpSocket::bind((Ipv4Addr::UNSPECIFIED, 0)).await?; + + for _ in 0..N_STEPS { + let _ = socket.send_to(&[1, 2, 3], dst).await?; + info!("UDP packet has been sent"); + + tokio::time::sleep(tick).await; + } + + Ok(()) + }); + + sim.run().unwrap(); +} diff --git a/examples/udp_vpv6_multicast/Cargo.toml b/examples/udp_vpv6_multicast/Cargo.toml new file mode 100644 index 0000000..f4f4414 --- /dev/null +++ b/examples/udp_vpv6_multicast/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "udp_ipv6_multicast" +version = "0.1.0" +edition = "2024" +publish = false + +[dependencies] +tokio = "1" +turmoil = { path = "../.." } +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/udp_vpv6_multicast/src/main.rs b/examples/udp_vpv6_multicast/src/main.rs new file mode 100644 index 0000000..9f7a39e --- /dev/null +++ b/examples/udp_vpv6_multicast/src/main.rs @@ -0,0 +1,56 @@ +use std::{net::Ipv6Addr, time::Duration}; +use tracing::info; +use turmoil::{IpVersion, net::UdpSocket}; + +const N_STEPS: usize = 3; + +fn main() { + tracing_subscriber::fmt() + .with_env_filter( + tracing_subscriber::EnvFilter::builder() + .with_default_directive(tracing::level_filters::LevelFilter::INFO.into()) + .from_env_lossy(), + ) + .init(); + + let tick = Duration::from_millis(100); + let mut sim = turmoil::Builder::new() + .tick_duration(tick) + .ip_version(IpVersion::V6) + .build(); + + let multicast_port = 9000; + let multicast_addr = "ff08::1".parse().unwrap(); + + for host_index in 0..2 { + sim.client(format!("server-{host_index}"), async move { + let socket = UdpSocket::bind((Ipv6Addr::UNSPECIFIED, multicast_port)).await?; + socket.join_multicast_v6(&multicast_addr, 0)?; + + let mut buf = [0; 1024]; + for _ in 0..N_STEPS { + let (n, addr) = socket.recv_from(&mut buf).await?; + let data = &buf[0..n]; + + info!("UDP packet from {} has been received: {:?}", addr, data); + } + Ok(()) + }); + } + + sim.client("client", async move { + let dst = (multicast_addr, multicast_port); + let socket = UdpSocket::bind((Ipv6Addr::UNSPECIFIED, 0)).await?; + + for _ in 0..N_STEPS { + let _ = socket.send_to(&[1, 2, 3], dst).await?; + info!("UDP packet has been sent"); + + tokio::time::sleep(tick).await; + } + + Ok(()) + }); + + sim.run().unwrap(); +} diff --git a/src/envelope.rs b/src/envelope.rs index 1f5690e..eb111ee 100644 --- a/src/envelope.rs +++ b/src/envelope.rs @@ -18,7 +18,7 @@ pub enum Protocol { } /// UDP datagram. -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct Datagram(pub Bytes); /// This is a simplification of real TCP. diff --git a/src/net/mod.rs b/src/net/mod.rs index c218314..4f67e80 100644 --- a/src/net/mod.rs +++ b/src/net/mod.rs @@ -8,7 +8,7 @@ use std::net::SocketAddr; pub mod tcp; pub use tcp::{listener::TcpListener, stream::TcpStream}; -mod udp; +pub(crate) mod udp; pub use udp::UdpSocket; #[derive(Debug, Copy, Clone, Hash, Eq, PartialEq)] diff --git a/src/net/udp.rs b/src/net/udp.rs index de2e9b7..c6bd102 100644 --- a/src/net/udp.rs +++ b/src/net/udp.rs @@ -1,4 +1,6 @@ use bytes::Bytes; +use indexmap::{IndexMap, IndexSet}; +use std::net::SocketAddr; use tokio::{ sync::{mpsc, Mutex}, time::sleep, @@ -13,7 +15,7 @@ use crate::{ use std::{ cmp, io::{self, Error, ErrorKind, Result}, - net::{Ipv6Addr, SocketAddr}, + net::{IpAddr, Ipv4Addr, Ipv6Addr}, }; /// A simulated UDP socket. @@ -32,6 +34,60 @@ impl std::fmt::Debug for UdpSocket { } } +#[derive(Debug, Default)] +pub(crate) struct MulticastGroups(IndexMap>); + +impl MulticastGroups { + fn destination_addresses(&self, multiaddr: &IpAddr) -> IndexSet { + self.0.get(multiaddr).cloned().unwrap_or_default() + } + + fn contains_destination_address(&self, multiaddr: &IpAddr, addr: &SocketAddr) -> bool { + self.0 + .get(multiaddr) + .and_then(|group| group.get(addr)) + .is_some() + } + + fn join(&mut self, multiaddr: IpAddr, addr: SocketAddr) { + self.0 + .entry(multiaddr) + .and_modify(|addrs| { + addrs.insert(addr); + tracing::info!(target: TRACING_TARGET, ?addr, group = ?multiaddr, protocol = %"UDP", "Join group"); + }) + .or_insert_with(|| IndexSet::from([addr])); + } + + fn leave(&mut self, multiaddr: IpAddr, addr: &SocketAddr) { + let index = self + .0 + .entry(multiaddr) + .and_modify(|group| { + group.swap_remove(addr); + tracing::info!(target: TRACING_TARGET, ?addr, group = ?multiaddr, protocol = %"UDP", "Leave group"); + }) + .index(); + + if self + .0 + .get_index(index) + .map(|(_, group)| group.is_empty()) + .unwrap_or(false) + { + self.0.swap_remove_index(index); + } + } + + fn leave_all(&mut self, addr: &SocketAddr) { + for (multiaddr, group) in self.0.iter_mut() { + group.swap_remove(addr); + tracing::info!(target: TRACING_TARGET, ?addr, group = ?multiaddr, protocol = %"UDP", "Leave group"); + } + self.0.retain(|_, group| !group.is_empty()); + } +} + struct Rx { recv: mpsc::Receiver<(Datagram, SocketAddr)>, /// A buffered received message. @@ -115,16 +171,7 @@ impl UdpSocket { let mut addr = addr.to_socket_addr(&world.dns); let host = world.current_host_mut(); - if !addr.ip().is_unspecified() && !addr.ip().is_loopback() { - return Err(Error::new( - ErrorKind::AddrNotAvailable, - format!("{addr} is not supported"), - )); - } - - if addr.is_ipv4() != host.addr.is_ipv4() { - panic!("ip version mismatch: {:?} host: {:?}", addr, host.addr) - } + verify_ipv4_bind_interface(addr.ip(), host.addr)?; if addr.port() == 0 { addr.set_port(host.assign_ephemeral_port()); @@ -289,8 +336,6 @@ impl UdpSocket { } fn send(&self, world: &mut World, dst: SocketAddr, packet: Datagram) -> Result<()> { - let msg = Protocol::Udp(packet); - let mut src = self.local_addr; if dst.ip().is_loopback() { src.set_ip(dst.ip()); @@ -299,10 +344,16 @@ impl UdpSocket { src.set_ip(world.current_host_mut().addr); } - if is_same(src, dst) { - send_loopback(src, dst, msg); + if dst.ip().is_multicast() { + world + .multicast_groups + .destination_addresses(&dst.ip()) + .into_iter() + .try_for_each(|dst| world.send_message(src, dst, Protocol::Udp(packet.clone())))? + } else if is_same(src, dst) { + send_loopback(src, dst, Protocol::Udp(packet)); } else { - world.send_message(src, dst, msg)?; + world.send_message(src, dst, Protocol::Udp(packet))?; } Ok(()) @@ -314,10 +365,122 @@ impl UdpSocket { Ok(()) } - /// Has no effect in turmoil. API parity with - /// https://docs.rs/tokio/latest/tokio/net/struct.UdpSocket.html#method.join_multicast_v6 - pub fn join_multicast_v6(&self, _multiaddr: &Ipv6Addr, _interface: u32) -> Result<()> { - Ok(()) + /// Executes an operation of the `IP_ADD_MEMBERSHIP` type. + /// + /// This function specifies a new multicast group for this socket to join. + /// The address must be a valid multicast address, and `interface` is the + /// address of the local interface with which the system should join the + /// multicast group. If it's equal to `INADDR_ANY` then an appropriate + /// interface is chosen by the system. + /// + /// Currently, the `interface` argument only supports `127.0.0.1` and `0.0.0.0`. + pub fn join_multicast_v4(&self, multiaddr: Ipv4Addr, interface: Ipv4Addr) -> Result<()> { + if !multiaddr.is_multicast() { + return Err(Error::new( + ErrorKind::InvalidInput, + "Invalid multicast address", + )); + } + + World::current(|world| { + let dst = destination_address(world, self); + verify_ipv4_bind_interface(interface, dst.ip())?; + + world.multicast_groups.join(IpAddr::V4(multiaddr), dst); + + Ok(()) + }) + } + + /// Executes an operation of the `IPV6_ADD_MEMBERSHIP` type. + /// + /// This function specifies a new multicast group for this socket to join. + /// The address must be a valid multicast address, and `interface` is the + /// index of the interface to join/leave (or 0 to indicate any interface). + /// + /// Currently, the `interface` argument only supports `0`. + pub fn join_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> Result<()> { + verify_ipv6_bind_interface(interface)?; + if !multiaddr.is_multicast() { + return Err(Error::new( + ErrorKind::InvalidInput, + "Invalid multicast address", + )); + } + + World::current(|world| { + let dst = destination_address(world, self); + + world.multicast_groups.join(IpAddr::V6(*multiaddr), dst); + + Ok(()) + }) + } + + /// Executes an operation of the `IP_DROP_MEMBERSHIP` type. + /// + /// For more information about this option, see [`join_multicast_v4`]. + /// + /// [`join_multicast_v4`]: method@Self::join_multicast_v4 + pub fn leave_multicast_v4(&self, multiaddr: Ipv4Addr, interface: Ipv4Addr) -> io::Result<()> { + if !multiaddr.is_multicast() { + return Err(Error::new( + ErrorKind::InvalidInput, + "Invalid multicast address", + )); + } + + World::current(|world| { + let dst = destination_address(world, self); + verify_ipv4_bind_interface(interface, dst.ip())?; + + if !world + .multicast_groups + .contains_destination_address(&IpAddr::V4(multiaddr), &dst) + { + return Err(Error::new( + ErrorKind::AddrNotAvailable, + "Leaving a multicast group that has not been previously joined", + )); + } + + world.multicast_groups.leave(IpAddr::V4(multiaddr), &dst); + + Ok(()) + }) + } + + /// Executes an operation of the `IPV6_DROP_MEMBERSHIP` type. + /// + /// For more information about this option, see [`join_multicast_v6`]. + /// + /// [`join_multicast_v6`]: method@Self::join_multicast_v6 + pub fn leave_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> { + verify_ipv6_bind_interface(interface)?; + if !multiaddr.is_multicast() { + return Err(Error::new( + ErrorKind::InvalidInput, + "Invalid multicast address", + )); + } + + World::current(|world| { + let dst = destination_address(world, self); + + if !world + .multicast_groups + .contains_destination_address(&IpAddr::V6(*multiaddr), &dst) + { + return Err(Error::new( + ErrorKind::AddrNotAvailable, + "Leaving a multicast group that has not been previously joined", + )); + } + + world.multicast_groups.leave(IpAddr::V6(*multiaddr), &dst); + + Ok(()) + }) } } @@ -338,8 +501,112 @@ fn send_loopback(src: SocketAddr, dst: SocketAddr, message: Protocol) { }); } +fn verify_ipv4_bind_interface(interface: A, addr: IpAddr) -> Result<()> +where + A: Into, +{ + let interface = interface.into(); + + if !interface.is_unspecified() && !interface.is_loopback() { + return Err(Error::new( + ErrorKind::AddrNotAvailable, + format!("{interface} is not supported"), + )); + } + + if interface.is_ipv4() != addr.is_ipv4() { + panic!("ip version mismatch: {:?} host: {:?}", interface, addr) + } + + Ok(()) +} + +fn verify_ipv6_bind_interface(interface: u32) -> Result<()> { + if interface != 0 { + return Err(Error::new( + ErrorKind::AddrNotAvailable, + format!("interface {interface} is not supported"), + )); + } + + Ok(()) +} + +fn destination_address(world: &World, socket: &UdpSocket) -> SocketAddr { + let local_port = socket + .local_addr() + .expect("local_addr is always present in simulation") + .port(); + let host_addr = world.current_host().addr; + SocketAddr::from((host_addr, local_port)) +} + impl Drop for UdpSocket { fn drop(&mut self) { - World::current_if_set(|world| world.current_host_mut().udp.unbind(self.local_addr)); + World::current_if_set(|world| { + world + .multicast_groups + .leave_all(&destination_address(world, self)); + world.current_host_mut().udp.unbind(self.local_addr); + }); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + mod multicast_group { + use super::*; + + #[test] + fn joining_does_not_produce_duplicate_addresses() { + let addr = "[fe80::1]:9000".parse().unwrap(); + let multiaddr = "ff08::1".parse().unwrap(); + let mut groups = MulticastGroups::default(); + groups.join(multiaddr, addr); + groups.join(multiaddr, addr); + + let addrs = groups.0.values().flatten().collect::>(); + assert_eq!(addrs.as_slice(), &[&addr]); + } + + #[test] + fn leaving_does_not_remove_entire_group() { + let addr1 = "[fe80::1]:9000".parse().unwrap(); + let addr2 = "[fe80::2]:9000".parse().unwrap(); + let multiaddr = "ff08::1".parse().unwrap(); + let mut groups = MulticastGroups::default(); + groups.join(multiaddr, addr1); + groups.join(multiaddr, addr2); + groups.leave(multiaddr, &addr2); + + let addrs = groups.0.values().flatten().collect::>(); + assert_eq!(addrs.as_slice(), &[&addr1]); + } + + #[test] + fn leaving_removes_empty_group() { + let addr = "[fe80::1]:9000".parse().unwrap(); + let multiaddr = "ff08::1".parse().unwrap(); + let mut groups = MulticastGroups::default(); + groups.join(multiaddr, addr); + groups.leave(multiaddr, &addr); + + assert_eq!(groups.0.len(), 0); + } + + #[test] + fn leaving_removes_empty_groups() { + let addr = "[fe80::1]:9000".parse().unwrap(); + let multiaddr1 = "ff08::1".parse().unwrap(); + let multiaddr2 = "ff08::2".parse().unwrap(); + let mut groups = MulticastGroups::default(); + groups.join(multiaddr1, addr); + groups.join(multiaddr2, addr); + groups.leave_all(&addr); + + assert_eq!(groups.0.len(), 0); + } } } diff --git a/src/world.rs b/src/world.rs index 9765528..8dad647 100644 --- a/src/world.rs +++ b/src/world.rs @@ -2,6 +2,7 @@ use crate::config::Config; use crate::envelope::Protocol; use crate::host::HostTimer; use crate::ip::IpVersionAddrIter; +use crate::net::udp::MulticastGroups; use crate::{ config, for_pairs, Dns, Host, Result as TurmoilResult, ToIpAddr, ToIpAddrs, Topology, TRACING_TARGET, @@ -26,6 +27,9 @@ pub(crate) struct World { /// Maps hostnames to ip addresses. pub(crate) dns: Dns, + // Maps multicast groups to udp destination addresses. + pub(crate) multicast_groups: MulticastGroups, + /// If set, this is the current host being executed. pub(crate) current: Option, @@ -52,6 +56,7 @@ impl World { hosts: IndexMap::new(), topology: Topology::new(link), dns: Dns::new(addrs), + multicast_groups: MulticastGroups::default(), current: None, rng, tick_duration, diff --git a/tests/udp.rs b/tests/udp.rs index b4f5925..d19f644 100644 --- a/tests/udp.rs +++ b/tests/udp.rs @@ -95,6 +95,96 @@ fn try_recv_pong(sock: &net::UdpSocket) -> Result<()> { Ok(()) } +#[test] +fn udp_ipv4_multicast() -> Result { + let mut sim = Builder::new() + .tick_duration(Duration::from_millis(50)) + .ip_version(IpVersion::V4) + .build(); + + let multicast_port = 9000; + let multicast_addr = "239.0.0.1".parse().unwrap(); + for server_index in 0..3 { + sim.client(format!("server-{server_index}"), async move { + let socket = UdpSocket::bind((Ipv4Addr::UNSPECIFIED, multicast_port)).await?; + socket.join_multicast_v4(multicast_addr, Ipv4Addr::UNSPECIFIED)?; + + let mut buf = [0; 1]; + socket.recv_from(&mut buf).await?; + assert_eq!([1], buf); + + socket.leave_multicast_v4(multicast_addr, Ipv4Addr::UNSPECIFIED)?; + + let is_timed_out = timeout(Duration::from_secs(1), socket.recv_from(&mut buf)) + .await + .is_err(); + assert!(is_timed_out); + + Ok(()) + }); + } + sim.client("client", async move { + let socket = UdpSocket::bind((Ipv4Addr::UNSPECIFIED, 0)).await?; + let dst = (multicast_addr, multicast_port); + + let _ = socket.send_to(&[1], dst).await?; + + // Give ”server” time to leave the multicast group. + tokio::time::sleep(Duration::from_millis(100)).await; + + let _ = socket.send_to(&[2], dst).await?; + + Ok(()) + }); + + sim.run() +} + +#[test] +fn udp_ipv6_multicast() -> Result { + let mut sim = Builder::new() + .tick_duration(Duration::from_millis(50)) + .ip_version(IpVersion::V6) + .build(); + + let multicast_port = 9000; + let multicast_addr = "ff08::1".parse().unwrap(); + for server_index in 0..3 { + sim.client(format!("server-{server_index}"), async move { + let socket = UdpSocket::bind((Ipv6Addr::UNSPECIFIED, multicast_port)).await?; + socket.join_multicast_v6(&multicast_addr, 0)?; + + let mut buf = [0; 1]; + socket.recv_from(&mut buf).await?; + assert_eq!([1], buf); + + socket.leave_multicast_v6(&multicast_addr, 0)?; + + let is_timed_out = timeout(Duration::from_secs(1), socket.recv_from(&mut buf)) + .await + .is_err(); + assert!(is_timed_out); + + Ok(()) + }); + } + sim.client("client", async move { + let socket = UdpSocket::bind((Ipv6Addr::UNSPECIFIED, 0)).await?; + let dst = (multicast_addr, multicast_port); + + let _ = socket.send_to(&[1], dst).await?; + + // Give ”server” time to leave the multicast group. + tokio::time::sleep(Duration::from_millis(100)).await; + + let _ = socket.send_to(&[2], dst).await?; + + Ok(()) + }); + + sim.run() +} + #[test] fn ping_pong() -> Result { let mut sim = Builder::new().build(); @@ -400,7 +490,7 @@ fn bind_ipv6_version_missmatch() { } #[test] -fn non_zero_bind() -> Result { +fn non_zero_ipv4_bind() -> Result { let mut sim = Builder::new().ip_version(IpVersion::V4).build(); sim.client("client", async move { let sock = UdpSocket::bind("1.1.1.1:1").await; @@ -408,7 +498,7 @@ fn non_zero_bind() -> Result { let Err(err) = sock else { panic!("socket creation should have failed") }; - assert_eq!(err.to_string(), "1.1.1.1:1 is not supported"); + assert_eq!(err.to_string(), "1.1.1.1 is not supported"); Ok(()) }); sim.run()