diff --git a/lightning/src/ln/peer_handler.rs b/lightning/src/ln/peer_handler.rs
index 31bcd855d0c..1aba422213b 100644
--- a/lightning/src/ln/peer_handler.rs
+++ b/lightning/src/ln/peer_handler.rs
@@ -1891,15 +1891,12 @@ impl {
@@ -2520,12 +2517,11 @@ mod tests {
 	use crate::prelude::*;
 	use crate::sync::{Arc, Mutex};
 
-	use core::convert::Infallible;
-	use core::sync::atomic::{AtomicBool, Ordering};
+	use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
 
 	#[derive(Clone)]
 	struct FileDescriptor {
-		fd: u16,
+		fd: u32,
 		outbound_data: Arc<Mutex<Vec<u8>>>,
 		disconnect: Arc<AtomicBool>,
 	}
@@ -2560,24 +2556,44 @@ mod tests {
 	struct TestCustomMessageHandler {
 		features: InitFeatures,
+		peer_counter: AtomicUsize,
+		send_messages: Option<PublicKey>,
+	}
+
+	impl crate::ln::wire::Type for u64 {
+		fn type_id(&self) -> u16 { 4242 }
 	}
 
 	impl wire::CustomMessageReader for TestCustomMessageHandler {
-		type CustomMessage = Infallible;
-		fn read<R: io::Read>(&self, _: u16, _: &mut R) -> Result<Option<Self::CustomMessage>, msgs::DecodeError> {
-			Ok(None)
+		type CustomMessage = u64;
+		fn read<R: io::Read>(&self, msg_type: u16, reader: &mut R) -> Result<Option<Self::CustomMessage>, msgs::DecodeError> {
+			assert!(self.send_messages.is_some());
+			assert_eq!(msg_type, 4242);
+			let mut msg = [0u8; 8];
+			reader.read_exact(&mut msg).unwrap();
+			Ok(Some(u64::from_be_bytes(msg)))
 		}
 	}
 
 	impl CustomMessageHandler for TestCustomMessageHandler {
-		fn handle_custom_message(&self, _: Infallible, _: &PublicKey) -> Result<(), LightningError> {
-			unreachable!();
+		fn handle_custom_message(&self, msg: u64, _: &PublicKey) -> Result<(), LightningError> {
+			assert_eq!(self.peer_counter.load(Ordering::Acquire) as u64, msg);
+			Ok(())
 		}
 
-		fn get_and_clear_pending_msg(&self) -> Vec<(PublicKey, Self::CustomMessage)> { Vec::new() }
+		fn get_and_clear_pending_msg(&self) -> Vec<(PublicKey, Self::CustomMessage)> {
+			if let Some(peer_node_id) = &self.send_messages {
+				vec![(*peer_node_id, self.peer_counter.load(Ordering::Acquire) as u64); 1000]
+			} else { Vec::new() }
+		}
 
-		fn peer_disconnected(&self, _: &PublicKey) {}
-		fn peer_connected(&self, _: &PublicKey, _: &msgs::Init, _: bool) -> Result<(), ()> { Ok(()) }
+		fn peer_disconnected(&self, _: &PublicKey) {
+			self.peer_counter.fetch_sub(1, Ordering::AcqRel);
+		}
+		fn peer_connected(&self, _: &PublicKey, _: &msgs::Init, _: bool) -> Result<(), ()> {
+			self.peer_counter.fetch_add(2, Ordering::AcqRel);
+			Ok(())
+		}
 
 		fn provided_node_features(&self) -> NodeFeatures {
 			NodeFeatures::empty()
 		}
@@ -2600,7 +2616,9 @@ mod tests {
 				chan_handler: test_utils::TestChannelMessageHandler::new(ChainHash::using_genesis_block(Network::Testnet)),
 				logger: test_utils::TestLogger::new(),
 				routing_handler: test_utils::TestRoutingMessageHandler::new(),
-				custom_handler: TestCustomMessageHandler { features },
+				custom_handler: TestCustomMessageHandler {
+					features, peer_counter: AtomicUsize::new(0), send_messages: None,
+				},
 				node_signer: test_utils::TestNodeSigner::new(node_secret),
 			}
 		);
@@ -2623,7 +2641,9 @@ mod tests {
 				chan_handler: test_utils::TestChannelMessageHandler::new(ChainHash::using_genesis_block(Network::Testnet)),
 				logger: test_utils::TestLogger::new(),
 				routing_handler: test_utils::TestRoutingMessageHandler::new(),
-				custom_handler: TestCustomMessageHandler { features },
+				custom_handler: TestCustomMessageHandler {
+					features, peer_counter: AtomicUsize::new(0), send_messages: None,
+				},
 				node_signer: test_utils::TestNodeSigner::new(node_secret),
 			}
 		);
@@ -2643,7 +2663,9 @@ mod tests {
 				chan_handler: test_utils::TestChannelMessageHandler::new(network),
 				logger: test_utils::TestLogger::new(),
 				routing_handler: test_utils::TestRoutingMessageHandler::new(),
-				custom_handler: TestCustomMessageHandler { features },
+				custom_handler: TestCustomMessageHandler {
+					features, peer_counter: AtomicUsize::new(0), send_messages: None,
+				},
 				node_signer: test_utils::TestNodeSigner::new(node_secret),
 			}
 		);
@@ -3191,4 +3213,100 @@ mod tests {
 		thread_c.join().unwrap();
 		assert!(cfg[0].chan_handler.message_fetch_counter.load(Ordering::Acquire) >= 1);
 	}
+
+	#[test]
+	#[cfg(feature = "std")]
+	fn test_rapid_connect_events_order_multithreaded() {
+		// Previously, outbound messages held in `process_events` could race with peer
+		// disconnection, allowing a message intended for a peer before disconnection to be sent
+		// to the same peer after disconnection. Here we stress the handling of such messages by
+		// connecting two peers repeatedly in a loop with a `CustomMessageHandler` set to stream
+		// custom messages with a "connection id" to each other. That "connection id" (just the
+		// number of reconnections seen) should always line up across both peers, which we assert
+		// in the message handler.
+		let mut cfg = create_peermgr_cfgs(2);
+		cfg[0].custom_handler.send_messages =
+			Some(cfg[1].node_signer.get_node_id(Recipient::Node).unwrap());
+		cfg[1].custom_handler.send_messages =
+			Some(cfg[0].node_signer.get_node_id(Recipient::Node).unwrap());
+		let cfg = Arc::new(cfg);
+		// Until we have std::thread::scoped we have to unsafe { turn off the borrow checker }.
+		let mut peers = create_network(2, unsafe { &*(&*cfg as *const _) as &'static _ });
+		let peer_a = Arc::new(peers.pop().unwrap());
+		let peer_b = Arc::new(peers.pop().unwrap());
+
+		let exit_flag = Arc::new(AtomicBool::new(false));
+		macro_rules! spawn_thread { ($id: expr) => { {
+			let thread_peer_a = Arc::clone(&peer_a);
+			let thread_peer_b = Arc::clone(&peer_b);
+			let thread_exit = Arc::clone(&exit_flag);
+			std::thread::spawn(move || {
+				let id_a = thread_peer_a.node_signer.get_node_id(Recipient::Node).unwrap();
+				let mut fd_a = FileDescriptor {
+					fd: $id, outbound_data: Arc::new(Mutex::new(Vec::new())),
+					disconnect: Arc::new(AtomicBool::new(false)),
+				};
+				let addr_a = SocketAddress::TcpIpV4{addr: [127, 0, 0, 1], port: 1000};
+				let mut fd_b = FileDescriptor {
+					fd: $id, outbound_data: Arc::new(Mutex::new(Vec::new())),
+					disconnect: Arc::new(AtomicBool::new(false)),
+				};
+				let addr_b = SocketAddress::TcpIpV4{addr: [127, 0, 0, 1], port: 1001};
+				let initial_data = thread_peer_b.new_outbound_connection(id_a, fd_b.clone(), Some(addr_a.clone())).unwrap();
+				thread_peer_a.new_inbound_connection(fd_a.clone(), Some(addr_b.clone())).unwrap();
+				if thread_peer_a.read_event(&mut fd_a, &initial_data).is_err() {
+					thread_peer_b.socket_disconnected(&fd_b);
+					return;
+				}
+
+				loop {
+					if thread_exit.load(Ordering::Relaxed) {
+						thread_peer_a.socket_disconnected(&fd_a);
+						thread_peer_b.socket_disconnected(&fd_b);
+						return;
+					}
+					if fd_a.disconnect.load(Ordering::Relaxed) { return; }
+					if fd_b.disconnect.load(Ordering::Relaxed) { return; }
+
+					let data_a = fd_a.outbound_data.lock().unwrap().split_off(0);
+					if !data_a.is_empty() {
+						if thread_peer_b.read_event(&mut fd_b, &data_a).is_err() {
+							thread_peer_a.socket_disconnected(&fd_a);
+							return;
+						}
+					}
+
+					let data_b = fd_b.outbound_data.lock().unwrap().split_off(0);
+					if !data_b.is_empty() {
+						if thread_peer_a.read_event(&mut fd_a, &data_b).is_err() {
+							thread_peer_b.socket_disconnected(&fd_b);
+							return;
+						}
+					}
+				}
+			})
+		} } }
+
+		let mut threads = Vec::new();
+		{
+			let thread_peer_a = Arc::clone(&peer_a);
+			let thread_peer_b = Arc::clone(&peer_b);
+			let thread_exit = Arc::clone(&exit_flag);
+			threads.push(std::thread::spawn(move || {
+				while !thread_exit.load(Ordering::Relaxed) {
+					thread_peer_a.process_events();
+					thread_peer_b.process_events();
+				}
+			}));
+		}
+		for i in 0..1000 {
+			threads.push(spawn_thread!(i));
+		}
+		exit_flag.store(true, Ordering::Relaxed);
+		for thread in threads {
+			thread.join().unwrap();
+		}
+		assert_eq!(peer_a.peers.read().unwrap().len(), 0);
+		assert_eq!(peer_b.peers.read().unwrap().len(), 0);
+	}
 }
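
Note on the `peer_counter` accounting in the test handler above: `peer_connected` adds 2 while `peer_disconnected` subtracts 1, so the counter takes a strictly increasing, never-repeating value for each connection era. A stale message carrying a previous era's "connection id" therefore cannot match the receiver's current counter, which is exactly what `handle_custom_message` asserts. A minimal standalone model of that arithmetic (plain Rust, not LDK API, just the invariant the test relies on):

	use std::sync::atomic::{AtomicUsize, Ordering};

	fn main() {
		let counter = AtomicUsize::new(0);
		for n in 1usize..=5 {
			// peer_connected: +2, so while connected for the n-th time the value is n + 1
			counter.fetch_add(2, Ordering::AcqRel);
			assert_eq!(counter.load(Ordering::Acquire), n + 1);
			// peer_disconnected: -1, leaving n; the next connect then sees a fresh value
			counter.fetch_sub(1, Ordering::AcqRel);
			assert_eq!(counter.load(Ordering::Acquire), n);
		}
	}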
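
On the "Until we have std::thread::scoped" comment: Rust 1.63 stabilized `std::thread::scope`, which would let the test borrow `cfg` and the peer managers directly instead of the unsafe `&'static` cast, at the cost of raising the crate's MSRV. A sketch of the shape (hypothetical, not part of this patch):

	use std::sync::atomic::{AtomicBool, Ordering};

	fn main() {
		let exit_flag = AtomicBool::new(false);
		// Scoped threads may borrow locals; the scope joins all of them before it
		// returns, so no 'static lifetime (and no unsafe cast) is needed.
		std::thread::scope(|s| {
			s.spawn(|| {
				while !exit_flag.load(Ordering::Relaxed) {
					// drive process_events() here
				}
			});
			exit_flag.store(true, Ordering::Relaxed);
		});
	}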