Skip to content

Commit

Permalink
Assert peer_{dis,}connected consistency across test handlers
Browse files Browse the repository at this point in the history
This adds a `ConnectionTracker` test util which is used across
`TestChannelMessageHandler`, `TestRoutingMessageHandler` and
`TestCustomMessageHandler`, asserting that `peer_connected` and
`peer_disconnected` methods are well-ordered. This expands test
coverage from just `TestChannelMessageHandler` to cover all test
handlers and adds some useful features which we'll use to test
the fix in the next commit.

This also adds an additional test which tests
`peer_{dis,}connected` consistency when a handler refuses a
connection by returning an `Err` from `peer_connected`.
  • Loading branch information
TheBlueMatt committed Jan 30, 2025
1 parent 07148db commit 4bc597a
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 22 deletions.
103 changes: 88 additions & 15 deletions lightning/src/ln/peer_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2867,6 +2867,16 @@ mod tests {

struct TestCustomMessageHandler {
features: InitFeatures,
conn_tracker: test_utils::ConnectionTracker,
}

impl TestCustomMessageHandler {
fn new(features: InitFeatures) -> Self {
Self {
features,
conn_tracker: test_utils::ConnectionTracker::new(),
}
}
}

impl wire::CustomMessageReader for TestCustomMessageHandler {
Expand All @@ -2883,10 +2893,13 @@ mod tests {

fn get_and_clear_pending_msg(&self) -> Vec<(PublicKey, Self::CustomMessage)> { Vec::new() }

fn peer_disconnected(&self, their_node_id: PublicKey) {
self.conn_tracker.peer_disconnected(their_node_id);
}

fn peer_disconnected(&self, _their_node_id: PublicKey) {}

fn peer_connected(&self, _their_node_id: PublicKey, _msg: &Init, _inbound: bool) -> Result<(), ()> { Ok(()) }
fn peer_connected(&self, their_node_id: PublicKey, _msg: &Init, _inbound: bool) -> Result<(), ()> {
self.conn_tracker.peer_connected(their_node_id)
}

fn provided_node_features(&self) -> NodeFeatures { NodeFeatures::empty() }

Expand All @@ -2909,7 +2922,7 @@ mod tests {
chan_handler: test_utils::TestChannelMessageHandler::new(ChainHash::using_genesis_block(Network::Testnet)),
logger: test_utils::TestLogger::with_id(i.to_string()),
routing_handler: test_utils::TestRoutingMessageHandler::new(),
custom_handler: TestCustomMessageHandler { features },
custom_handler: TestCustomMessageHandler::new(features),
node_signer: test_utils::TestNodeSigner::new(node_secret),
}
);
Expand All @@ -2932,7 +2945,7 @@ mod tests {
chan_handler: test_utils::TestChannelMessageHandler::new(ChainHash::using_genesis_block(Network::Testnet)),
logger: test_utils::TestLogger::new(),
routing_handler: test_utils::TestRoutingMessageHandler::new(),
custom_handler: TestCustomMessageHandler { features },
custom_handler: TestCustomMessageHandler::new(features),
node_signer: test_utils::TestNodeSigner::new(node_secret),
}
);
Expand All @@ -2952,7 +2965,7 @@ mod tests {
chan_handler: test_utils::TestChannelMessageHandler::new(network),
logger: test_utils::TestLogger::new(),
routing_handler: test_utils::TestRoutingMessageHandler::new(),
custom_handler: TestCustomMessageHandler { features },
custom_handler: TestCustomMessageHandler::new(features),
node_signer: test_utils::TestNodeSigner::new(node_secret),
}
);
Expand All @@ -2976,19 +2989,16 @@ mod tests {
peers
}

fn establish_connection<'a>(peer_a: &PeerManager<FileDescriptor, &'a test_utils::TestChannelMessageHandler, &'a test_utils::TestRoutingMessageHandler, IgnoringMessageHandler, &'a test_utils::TestLogger, &'a TestCustomMessageHandler, &'a test_utils::TestNodeSigner>, peer_b: &PeerManager<FileDescriptor, &'a test_utils::TestChannelMessageHandler, &'a test_utils::TestRoutingMessageHandler, IgnoringMessageHandler, &'a test_utils::TestLogger, &'a TestCustomMessageHandler, &'a test_utils::TestNodeSigner>) -> (FileDescriptor, FileDescriptor) {
fn try_establish_connection<'a>(peer_a: &PeerManager<FileDescriptor, &'a test_utils::TestChannelMessageHandler, &'a test_utils::TestRoutingMessageHandler, IgnoringMessageHandler, &'a test_utils::TestLogger, &'a TestCustomMessageHandler, &'a test_utils::TestNodeSigner>, peer_b: &PeerManager<FileDescriptor, &'a test_utils::TestChannelMessageHandler, &'a test_utils::TestRoutingMessageHandler, IgnoringMessageHandler, &'a test_utils::TestLogger, &'a TestCustomMessageHandler, &'a test_utils::TestNodeSigner>) -> (FileDescriptor, FileDescriptor, Result<bool, PeerHandleError>, Result<bool, PeerHandleError>) {
let addr_a = SocketAddress::TcpIpV4{addr: [127, 0, 0, 1], port: 1000};
let addr_b = SocketAddress::TcpIpV4{addr: [127, 0, 0, 1], port: 1001};

static FD_COUNTER: AtomicUsize = AtomicUsize::new(0);
let fd = FD_COUNTER.fetch_add(1, Ordering::Relaxed) as u16;

let id_a = peer_a.node_signer.get_node_id(Recipient::Node).unwrap();
let mut fd_a = FileDescriptor::new(fd);
let addr_a = SocketAddress::TcpIpV4{addr: [127, 0, 0, 1], port: 1000};

let id_b = peer_b.node_signer.get_node_id(Recipient::Node).unwrap();
let features_a = peer_a.init_features(id_b);
let features_b = peer_b.init_features(id_a);
let mut fd_b = FileDescriptor::new(fd);
let addr_b = SocketAddress::TcpIpV4{addr: [127, 0, 0, 1], port: 1001};

let initial_data = peer_b.new_outbound_connection(id_a, fd_b.clone(), Some(addr_a.clone())).unwrap();
peer_a.new_inbound_connection(fd_a.clone(), Some(addr_b.clone())).unwrap();
Expand All @@ -3000,11 +3010,30 @@ mod tests {

peer_b.process_events();
let b_data = fd_b.outbound_data.lock().unwrap().split_off(0);
assert_eq!(peer_a.read_event(&mut fd_a, &b_data).unwrap(), false);
let a_refused = peer_a.read_event(&mut fd_a, &b_data);

peer_a.process_events();
let a_data = fd_a.outbound_data.lock().unwrap().split_off(0);
assert_eq!(peer_b.read_event(&mut fd_b, &a_data).unwrap(), false);
let b_refused = peer_b.read_event(&mut fd_b, &a_data);

(fd_a, fd_b, a_refused, b_refused)
}


fn establish_connection<'a>(peer_a: &PeerManager<FileDescriptor, &'a test_utils::TestChannelMessageHandler, &'a test_utils::TestRoutingMessageHandler, IgnoringMessageHandler, &'a test_utils::TestLogger, &'a TestCustomMessageHandler, &'a test_utils::TestNodeSigner>, peer_b: &PeerManager<FileDescriptor, &'a test_utils::TestChannelMessageHandler, &'a test_utils::TestRoutingMessageHandler, IgnoringMessageHandler, &'a test_utils::TestLogger, &'a TestCustomMessageHandler, &'a test_utils::TestNodeSigner>) -> (FileDescriptor, FileDescriptor) {
let addr_a = SocketAddress::TcpIpV4{addr: [127, 0, 0, 1], port: 1000};
let addr_b = SocketAddress::TcpIpV4{addr: [127, 0, 0, 1], port: 1001};

let id_a = peer_a.node_signer.get_node_id(Recipient::Node).unwrap();
let id_b = peer_b.node_signer.get_node_id(Recipient::Node).unwrap();

let features_a = peer_a.init_features(id_b);
let features_b = peer_b.init_features(id_a);

let (fd_a, fd_b, a_refused, b_refused) = try_establish_connection(peer_a, peer_b);

assert_eq!(a_refused.unwrap(), false);
assert_eq!(b_refused.unwrap(), false);

assert_eq!(peer_a.peer_by_node_id(&id_b).unwrap().counterparty_node_id, id_b);
assert_eq!(peer_a.peer_by_node_id(&id_b).unwrap().socket_address, Some(addr_b));
Expand Down Expand Up @@ -3257,6 +3286,50 @@ mod tests {
assert_eq!(peers[0].peers.read().unwrap().len(), 0);
}

fn do_test_peer_connected_error_disconnects(handler: usize) {
// Test that if a message handler fails a connection in `peer_connected` we reliably
// produce `peer_disconnected` events for all other message handlers (that saw a
// corresponding `peer_connected`).
let cfgs = create_peermgr_cfgs(2);
let peers = create_network(2, &cfgs);

match handler & !1 {
0 => {
peers[handler & 1].message_handler.chan_handler.conn_tracker.fail_connections.store(true, Ordering::Release);
}
2 => {
peers[handler & 1].message_handler.route_handler.conn_tracker.fail_connections.store(true, Ordering::Release);
}
4 => {
peers[handler & 1].message_handler.custom_message_handler.conn_tracker.fail_connections.store(true, Ordering::Release);
}
_ => panic!(),
}
let (_sd1, _sd2, a_refused, b_refused) = try_establish_connection(&peers[0], &peers[1]);
if handler & 1 == 0 {
assert!(a_refused.is_err());
assert!(peers[0].list_peers().is_empty());
} else {
assert!(b_refused.is_err());
assert!(peers[1].list_peers().is_empty());
}
// At least one message handler should have seen the connection.
assert!(peers[handler & 1].message_handler.chan_handler.conn_tracker.had_peers.load(Ordering::Acquire) ||
peers[handler & 1].message_handler.route_handler.conn_tracker.had_peers.load(Ordering::Acquire) ||
peers[handler & 1].message_handler.custom_message_handler.conn_tracker.had_peers.load(Ordering::Acquire));
// And both message handlers doing tracking should see the disconnection
assert!(peers[handler & 1].message_handler.chan_handler.conn_tracker.connected_peers.lock().unwrap().is_empty());
assert!(peers[handler & 1].message_handler.route_handler.conn_tracker.connected_peers.lock().unwrap().is_empty());
assert!(peers[handler & 1].message_handler.custom_message_handler.conn_tracker.connected_peers.lock().unwrap().is_empty());
}

#[test]
fn test_peer_connected_error_disconnects() {
for i in 0..6 {
do_test_peer_connected_error_disconnects(i);
}
}

#[test]
fn test_do_attempt_write_data() {
// Create 2 peers with custom TestRoutingMessageHandlers and connect them.
Expand Down
53 changes: 46 additions & 7 deletions lightning/src/util/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -889,10 +889,45 @@ impl chaininterface::BroadcasterInterface for TestBroadcaster {
}
}

pub struct ConnectionTracker {
pub had_peers: AtomicBool,
pub connected_peers: Mutex<Vec<PublicKey>>,
pub fail_connections: AtomicBool,
}

impl ConnectionTracker {
pub fn new() -> Self {
Self {
had_peers: AtomicBool::new(false),
connected_peers: Mutex::new(Vec::new()),
fail_connections: AtomicBool::new(false),
}
}

pub fn peer_connected(&self, their_node_id: PublicKey) -> Result<(), ()> {
self.had_peers.store(true, Ordering::Release);
let mut connected_peers = self.connected_peers.lock().unwrap();
assert!(!connected_peers.contains(&their_node_id));
if self.fail_connections.load(Ordering::Acquire) {
Err(())
} else {
connected_peers.push(their_node_id);
Ok(())
}
}

pub fn peer_disconnected(&self, their_node_id: PublicKey) {
assert!(self.had_peers.load(Ordering::Acquire));
let mut connected_peers = self.connected_peers.lock().unwrap();
assert!(connected_peers.contains(&their_node_id));
connected_peers.retain(|id| *id != their_node_id);
}
}

pub struct TestChannelMessageHandler {
pub pending_events: Mutex<Vec<events::MessageSendEvent>>,
expected_recv_msgs: Mutex<Option<Vec<wire::Message<()>>>>,
connected_peers: Mutex<HashSet<PublicKey>>,
pub conn_tracker: ConnectionTracker,
chain_hash: ChainHash,
}

Expand All @@ -907,7 +942,7 @@ impl TestChannelMessageHandler {
TestChannelMessageHandler {
pending_events: Mutex::new(Vec::new()),
expected_recv_msgs: Mutex::new(None),
connected_peers: Mutex::new(new_hash_set()),
conn_tracker: ConnectionTracker::new(),
chain_hash,
}
}
Expand Down Expand Up @@ -1019,15 +1054,14 @@ impl msgs::ChannelMessageHandler for TestChannelMessageHandler {
self.received_msg(wire::Message::ChannelReestablish(msg.clone()));
}
fn peer_disconnected(&self, their_node_id: PublicKey) {
assert!(self.connected_peers.lock().unwrap().remove(&their_node_id));
self.conn_tracker.peer_disconnected(their_node_id)
}
fn peer_connected(
&self, their_node_id: PublicKey, _msg: &msgs::Init, _inbound: bool,
) -> Result<(), ()> {
assert!(self.connected_peers.lock().unwrap().insert(their_node_id.clone()));
// Don't bother with `received_msg` for Init as its auto-generated and we don't want to
// bother re-generating the expected Init message in all tests.
Ok(())
self.conn_tracker.peer_connected(their_node_id)
}
fn handle_error(&self, _their_node_id: PublicKey, msg: &msgs::ErrorMessage) {
self.received_msg(wire::Message::Error(msg.clone()));
Expand Down Expand Up @@ -1157,6 +1191,7 @@ pub struct TestRoutingMessageHandler {
pub pending_events: Mutex<Vec<events::MessageSendEvent>>,
pub request_full_sync: AtomicBool,
pub announcement_available_for_sync: AtomicBool,
pub conn_tracker: ConnectionTracker,
}

impl TestRoutingMessageHandler {
Expand All @@ -1168,6 +1203,7 @@ impl TestRoutingMessageHandler {
pending_events,
request_full_sync: AtomicBool::new(false),
announcement_available_for_sync: AtomicBool::new(false),
conn_tracker: ConnectionTracker::new(),
}
}
}
Expand Down Expand Up @@ -1242,10 +1278,13 @@ impl msgs::RoutingMessageHandler for TestRoutingMessageHandler {
timestamp_range: u32::max_value(),
},
});
Ok(())

self.conn_tracker.peer_connected(their_node_id)
}

fn peer_disconnected(&self, _their_node_id: PublicKey) {}
fn peer_disconnected(&self, their_node_id: PublicKey) {
self.conn_tracker.peer_disconnected(their_node_id);
}

fn handle_reply_channel_range(
&self, _their_node_id: PublicKey, _msg: msgs::ReplyChannelRange,
Expand Down

0 comments on commit 4bc597a

Please sign in to comment.