diff --git a/comms/dht/src/actor.rs b/comms/dht/src/actor.rs index 0e0b6e6eea..c2c2d4e52a 100644 --- a/comms/dht/src/actor.rs +++ b/comms/dht/src/actor.rs @@ -28,7 +28,7 @@ //! [DhtRequest]: ./enum.DhtRequest.html use crate::{ - broadcast_strategy::BroadcastStrategy, + broadcast_strategy::{BroadcastClosestRequest, BroadcastStrategy}, dedup::DedupCacheDatabase, discovery::DhtDiscoveryError, outbound::{DhtOutboundError, OutboundMessageRequester, SendMessageParams}, @@ -416,43 +416,19 @@ impl DhtActor { .await?; Ok(peers.into_iter().map(|p| p.peer_node_id().clone()).collect()) }, - Closest(closest_request) => { - let connections = connectivity - .select_connections(ConnectivitySelection::closest_to( - closest_request.node_id.clone(), - config.broadcast_factor, - closest_request.excluded_peers.clone(), - )) - .await?; - - let mut candidates = connections - .iter() - .map(|conn| conn.peer_node_id()) - .cloned() - .collect::>(); - - if !closest_request.connected_only { - let excluded = closest_request - .excluded_peers - .iter() - .chain(candidates.iter()) - .cloned() - .collect::>(); - // If we don't have enough connections, let's select some more disconnected peers (at least 2) - let n = cmp::max(config.broadcast_factor.saturating_sub(candidates.len()), 2); - let additional = Self::select_closest_peers_for_propagation( - &peer_manager, - &closest_request.node_id, - n, - &excluded, - PeerFeatures::MESSAGE_PROPAGATION, - ) - .await?; - - candidates.extend(additional); + ClosestNodes(closest_request) => { + Self::select_closest_node_connected(closest_request, config, connectivity, peer_manager).await + }, + DirectOrClosestNodes(closest_request) => { + // First check if a direct connection exists + if connectivity + .get_connection(closest_request.node_id.clone()) + .await? + .is_some() + { + return Ok(vec![closest_request.node_id.clone()]); } - - Ok(candidates) + Self::select_closest_node_connected(closest_request, config, connectivity, peer_manager).await }, Random(n, excluded) => { // Send to a random set of peers of size n that are Communication Nodes @@ -659,6 +635,50 @@ impl DhtActor { Ok(peers.into_iter().map(|p| p.node_id).collect()) } + + async fn select_closest_node_connected( + closest_request: Box, + config: DhtConfig, + mut connectivity: ConnectivityRequester, + peer_manager: Arc, + ) -> Result, DhtActorError> { + let connections = connectivity + .select_connections(ConnectivitySelection::closest_to( + closest_request.node_id.clone(), + config.broadcast_factor, + closest_request.excluded_peers.clone(), + )) + .await?; + + let mut candidates = connections + .iter() + .map(|conn| conn.peer_node_id()) + .cloned() + .collect::>(); + + if !closest_request.connected_only { + let excluded = closest_request + .excluded_peers + .iter() + .chain(candidates.iter()) + .cloned() + .collect::>(); + // If we don't have enough connections, let's select some more disconnected peers (at least 2) + let n = cmp::max(config.broadcast_factor.saturating_sub(candidates.len()), 2); + let additional = Self::select_closest_peers_for_propagation( + &peer_manager, + &closest_request.node_id, + n, + &excluded, + PeerFeatures::MESSAGE_PROPAGATION, + ) + .await?; + + candidates.extend(additional); + } + + Ok(candidates) + } } #[cfg(test)] @@ -888,6 +908,7 @@ mod test { connectivity_manager_mock_state .set_selected_connections(vec![conn_out.clone()]) .await; + let peers = requester .select_peers(BroadcastStrategy::Broadcast(Vec::new())) .await @@ -915,7 +936,29 @@ mod test { connected_only: false, }); let peers = requester - .select_peers(BroadcastStrategy::Closest(send_request)) + .select_peers(BroadcastStrategy::ClosestNodes(send_request)) + .await + .unwrap(); + assert_eq!(peers.len(), 2); + + let send_request = Box::new(BroadcastClosestRequest { + node_id: node_identity.node_id().clone(), + excluded_peers: vec![], + connected_only: false, + }); + let peers = requester + .select_peers(BroadcastStrategy::DirectOrClosestNodes(send_request)) + .await + .unwrap(); + assert_eq!(peers.len(), 1); + + let send_request = Box::new(BroadcastClosestRequest { + node_id: client_node_identity.node_id().clone(), + excluded_peers: vec![], + connected_only: false, + }); + let peers = requester + .select_peers(BroadcastStrategy::DirectOrClosestNodes(send_request)) .await .unwrap(); assert_eq!(peers.len(), 2); diff --git a/comms/dht/src/broadcast_strategy.rs b/comms/dht/src/broadcast_strategy.rs index 3e1b356067..9077cc3a58 100644 --- a/comms/dht/src/broadcast_strategy.rs +++ b/comms/dht/src/broadcast_strategy.rs @@ -57,7 +57,9 @@ pub enum BroadcastStrategy { /// Send to a random set of peers of size n that are Communication Nodes, excluding the given node IDs Random(usize, Vec), /// Send to all n nearest Communication Nodes according to the given BroadcastClosestRequest - Closest(Box), + ClosestNodes(Box), + /// Send directly to destination if connected but otherwise send to all n nearest Communication Nodes + DirectOrClosestNodes(Box), Broadcast(Vec), /// Propagate to a set of closest neighbours and random peers Propagate(NodeDestination, Vec), @@ -70,7 +72,8 @@ impl fmt::Display for BroadcastStrategy { DirectPublicKey(pk) => write!(f, "DirectPublicKey({})", pk), DirectNodeId(node_id) => write!(f, "DirectNodeId({})", node_id), Flood(excluded) => write!(f, "Flood({} excluded)", excluded.len()), - Closest(request) => write!(f, "Closest({})", request), + ClosestNodes(request) => write!(f, "ClosestNodes({})", request), + DirectOrClosestNodes(request) => write!(f, "DirectOrClosestNodes({})", request), Random(n, excluded) => write!(f, "Random({}, {} excluded)", n, excluded.len()), Broadcast(excluded) => write!(f, "Broadcast({} excluded)", excluded.len()), Propagate(destination, excluded) => write!(f, "Propagate({}, {} excluded)", destination, excluded.len(),), @@ -79,13 +82,18 @@ impl fmt::Display for BroadcastStrategy { } impl BroadcastStrategy { - /// Returns true if this strategy will send multiple messages, otherwise false - pub fn is_multi_message(&self) -> bool { + /// Returns true if this strategy will send multiple indirect messages, otherwise false + pub fn is_multi_message(&self, chosen_peers: &[NodeId]) -> bool { use BroadcastStrategy::*; - matches!( - self, - Closest(_) | Flood(_) | Broadcast(_) | Random(_, _) | Propagate(_, _) - ) + + match self { + DirectOrClosestNodes(strategy) => { + // Testing if there is a single chosen peer and it is the target NodeId + chosen_peers.len() == 1 && chosen_peers.first() == Some(&strategy.node_id) + }, + ClosestNodes(_) | Broadcast(_) | Propagate(_, _) | Flood(_) | Random(_, _) => true, + _ => false, + } } pub fn is_direct(&self) -> bool { @@ -129,7 +137,7 @@ mod test { assert!(!BroadcastStrategy::Broadcast(Default::default()).is_direct()); assert!(!BroadcastStrategy::Propagate(Default::default(), Default::default()).is_direct(),); assert!(!BroadcastStrategy::Flood(Default::default()).is_direct()); - assert!(!BroadcastStrategy::Closest(Box::new(BroadcastClosestRequest { + assert!(!BroadcastStrategy::ClosestNodes(Box::new(BroadcastClosestRequest { node_id: NodeId::default(), excluded_peers: Default::default(), connected_only: false @@ -152,7 +160,7 @@ mod test { assert!(BroadcastStrategy::Flood(Default::default()) .direct_public_key() .is_none()); - assert!(BroadcastStrategy::Closest(Box::new(BroadcastClosestRequest { + assert!(BroadcastStrategy::ClosestNodes(Box::new(BroadcastClosestRequest { node_id: NodeId::default(), excluded_peers: Default::default(), connected_only: false @@ -174,7 +182,7 @@ mod test { .direct_node_id() .is_none()); assert!(BroadcastStrategy::Flood(Default::default()).direct_node_id().is_none()); - assert!(BroadcastStrategy::Closest(Box::new(BroadcastClosestRequest { + assert!(BroadcastStrategy::ClosestNodes(Box::new(BroadcastClosestRequest { node_id: NodeId::default(), excluded_peers: Default::default(), connected_only: false diff --git a/comms/dht/src/config.rs b/comms/dht/src/config.rs index 90fc9b8b72..0612445dca 100644 --- a/comms/dht/src/config.rs +++ b/comms/dht/src/config.rs @@ -66,11 +66,6 @@ pub struct DhtConfig { pub saf_max_message_size: usize, /// When true, store and forward messages are requested from peers on connect (Default: true) pub saf_auto_request: bool, - /// The minimum period used to request SAF messages from a peer. When requesting SAF messages, - /// it will request messages since the DHT last went offline, but this may be a small amount of - /// time, so `minimum_request_period` can be used so that messages aren't missed. - /// Default: 3 days - pub saf_minimum_request_period: Duration, /// The max capacity of the message hash cache /// Default: 2,500 pub dedup_cache_capacity: usize, @@ -154,7 +149,6 @@ impl Default for DhtConfig { saf_high_priority_msg_storage_ttl: Duration::from_secs(3 * 24 * 60 * 60), // 3 days saf_auto_request: true, saf_max_message_size: 512 * 1024, - saf_minimum_request_period: Duration::from_secs(3 * 24 * 60 * 60), // 3 days dedup_cache_capacity: 2_500, dedup_cache_trim_interval: Duration::from_secs(5 * 60), database_url: DbConnectionUrl::Memory, diff --git a/comms/dht/src/outbound/broadcast.rs b/comms/dht/src/outbound/broadcast.rs index a3b122f8ab..0aa9fab611 100644 --- a/comms/dht/src/outbound/broadcast.rs +++ b/comms/dht/src/outbound/broadcast.rs @@ -268,7 +268,7 @@ where S: Service is_discovery_enabled, ); - let is_broadcast = broadcast_strategy.is_multi_message(); + let is_broadcast = broadcast_strategy.is_multi_message(&peers); // Discovery is required if: // - Discovery is enabled for this request diff --git a/comms/dht/src/outbound/message_params.rs b/comms/dht/src/outbound/message_params.rs index ffc463771a..0ad00bbc4e 100644 --- a/comms/dht/src/outbound/message_params.rs +++ b/comms/dht/src/outbound/message_params.rs @@ -116,7 +116,7 @@ impl SendMessageParams { /// `node_id` - Select the closest known peers to this `NodeId` /// `excluded_peers` - vector of `NodeId`s to exclude from broadcast. pub fn closest(&mut self, node_id: NodeId, excluded_peers: Vec) -> &mut Self { - self.params_mut().broadcast_strategy = BroadcastStrategy::Closest(Box::new(BroadcastClosestRequest { + self.params_mut().broadcast_strategy = BroadcastStrategy::ClosestNodes(Box::new(BroadcastClosestRequest { excluded_peers, node_id, connected_only: false, @@ -124,10 +124,10 @@ impl SendMessageParams { self } - /// Set broadcast_strategy to Closest.`excluded_peers` are excluded. Only peers that are currently connected will be - /// included. + /// Set broadcast_strategy to ClosestNodes.`excluded_peers` are excluded. Only peers that are currently connected + /// will be included. pub fn closest_connected(&mut self, node_id: NodeId, excluded_peers: Vec) -> &mut Self { - self.params_mut().broadcast_strategy = BroadcastStrategy::Closest(Box::new(BroadcastClosestRequest { + self.params_mut().broadcast_strategy = BroadcastStrategy::ClosestNodes(Box::new(BroadcastClosestRequest { excluded_peers, node_id, connected_only: true, @@ -135,6 +135,18 @@ impl SendMessageParams { self } + /// Set broadcast_strategy to DirectOrClosestNodes.`excluded_peers` are excluded. Only peers that are currently + /// connected will be included. + pub fn direct_or_closest_connected(&mut self, node_id: NodeId, excluded_peers: Vec) -> &mut Self { + self.params_mut().broadcast_strategy = + BroadcastStrategy::DirectOrClosestNodes(Box::new(BroadcastClosestRequest { + excluded_peers, + node_id, + connected_only: true, + })); + self + } + /// Set broadcast_strategy to Neighbours. `excluded_peers` are excluded. Only Peers that have /// `PeerFeatures::MESSAGE_PROPAGATION` are included. pub fn broadcast(&mut self, excluded_peers: Vec) -> &mut Self { diff --git a/comms/dht/src/outbound/mock.rs b/comms/dht/src/outbound/mock.rs index 6cf4b83e40..f5c3f30665 100644 --- a/comms/dht/src/outbound/mock.rs +++ b/comms/dht/src/outbound/mock.rs @@ -205,7 +205,7 @@ impl OutboundServiceMock { }, }; }, - BroadcastStrategy::Closest(_) => { + BroadcastStrategy::ClosestNodes(_) => { if behaviour.broadcast == ResponseType::Queued { let (response, mut inner_reply_tx) = self.add_call((*params).clone(), body); reply_tx.send(response).expect("Reply channel cancelled"); diff --git a/comms/dht/src/storage/dht_setting_entry.rs b/comms/dht/src/storage/dht_setting_entry.rs index 73cb39fe69..dd1e06597f 100644 --- a/comms/dht/src/storage/dht_setting_entry.rs +++ b/comms/dht/src/storage/dht_setting_entry.rs @@ -27,6 +27,8 @@ use std::fmt; pub enum DhtMetadataKey { /// Timestamp each time the DHT is shut down OfflineTimestamp, + /// Timestamp of the most recent SAF message received + LastSafMessageReceived, } impl fmt::Display for DhtMetadataKey { diff --git a/comms/dht/src/store_forward/database/mod.rs b/comms/dht/src/store_forward/database/mod.rs index ec6b19a42e..173d00e0ef 100644 --- a/comms/dht/src/store_forward/database/mod.rs +++ b/comms/dht/src/store_forward/database/mod.rs @@ -217,6 +217,17 @@ impl StoreAndForwardDatabase { .await } + pub(crate) async fn delete_messages_older_than(&self, since: NaiveDateTime) -> Result { + self.connection + .with_connection_async(move |conn| { + diesel::delete(stored_messages::table) + .filter(stored_messages::stored_at.lt(since)) + .execute(conn) + .map_err(Into::into) + }) + .await + } + pub(crate) async fn truncate_messages(&self, max_size: usize) -> Result { self.connection .with_connection_async(move |conn| { diff --git a/comms/dht/src/store_forward/forward.rs b/comms/dht/src/store_forward/forward.rs index 607dfe0fd1..95ce5e2500 100644 --- a/comms/dht/src/store_forward/forward.rs +++ b/comms/dht/src/store_forward/forward.rs @@ -219,7 +219,7 @@ where S: Service target: LOG_TARGET, "Forwarding SAF message directly to node: {}, Tag#{}", node_id, dht_header.message_tag ); - send_params.closest_connected(node_id.clone(), excluded_peers); + send_params.direct_or_closest_connected(node_id.clone(), excluded_peers); }, _ => { debug!( diff --git a/comms/dht/src/store_forward/message.rs b/comms/dht/src/store_forward/message.rs index d29481f3f2..85ba721934 100644 --- a/comms/dht/src/store_forward/message.rs +++ b/comms/dht/src/store_forward/message.rs @@ -52,12 +52,17 @@ impl StoredMessagesRequest { #[cfg(test)] impl StoredMessage { - pub fn new(version: u32, dht_header: crate::envelope::DhtMessageHeader, body: Vec) -> Self { + pub fn new( + version: u32, + dht_header: crate::envelope::DhtMessageHeader, + body: Vec, + stored_at: DateTime, + ) -> Self { Self { version, dht_header: Some(dht_header.into()), body, - stored_at: Some(datetime_to_timestamp(Utc::now())), + stored_at: Some(datetime_to_timestamp(stored_at)), } } } diff --git a/comms/dht/src/store_forward/saf_handler/task.rs b/comms/dht/src/store_forward/saf_handler/task.rs index e32e3f60a1..f3ba852118 100644 --- a/comms/dht/src/store_forward/saf_handler/task.rs +++ b/comms/dht/src/store_forward/saf_handler/task.rs @@ -36,8 +36,10 @@ use crate::{ StoredMessagesResponse, }, }, + storage::DhtMetadataKey, store_forward::{error::StoreAndForwardError, service::FetchStoredMessageQuery, StoreAndForwardRequester}, }; +use chrono::{DateTime, NaiveDateTime, Utc}; use digest::Digest; use futures::{channel::mpsc, future, stream, SinkExt, StreamExt}; use log::*; @@ -172,15 +174,19 @@ where S: Service // Compile a set of stored messages for the requesting peer let mut query = FetchStoredMessageQuery::new(source_pubkey, source_node_id.clone()); - if let Some(since) = retrieve_msgs.since.map(timestamp_to_datetime) { - debug!( - target: LOG_TARGET, - "Peer '{}' requested all messages since '{}'", - source_node_id.short_str(), - since - ); - query.since(since); - } + let since: Option> = match retrieve_msgs.since.map(timestamp_to_datetime) { + Some(since) => { + debug!( + target: LOG_TARGET, + "Peer '{}' requested all messages since '{}'", + source_node_id.short_str(), + since + ); + query.with_messages_since(since); + Some(since) + }, + None => None, + }; let response_types = vec![SafResponseType::ForMe]; @@ -188,7 +194,6 @@ where S: Service query.with_response_type(resp_type); let messages = self.saf_requester.fetch_messages(query.clone()).await?; - let message_ids = messages.iter().map(|msg| msg.id).collect::>(); let stored_messages = StoredMessagesResponse { messages: try_convert_all(messages)?, request_id: retrieve_msgs.request_id, @@ -201,6 +206,7 @@ where S: Service stored_messages.messages().len(), resp_type ); + match self .outbound_service .send_message_no_header( @@ -215,13 +221,15 @@ where S: Service .await { Ok(_) => { - debug!( - target: LOG_TARGET, - "Removing {} stored message(s) for peer '{}'", - message_ids.len(), - message.source_peer.node_id.short_str() - ); - self.saf_requester.remove_messages(message_ids).await?; + if let Some(threshold) = since { + debug!( + target: LOG_TARGET, + "Removing stored message(s) from before {} for peer '{}'", + threshold, + message.source_peer.node_id.short_str() + ); + self.saf_requester.remove_messages_older_than(threshold).await?; + } }, Err(err) => { error!( @@ -366,6 +374,14 @@ where S: Service return Err(StoreAndForwardError::DhtHeaderNotProvided); } + let stored_at = match message.stored_at { + None => chrono::MIN_DATETIME, + Some(t) => DateTime::from_utc( + NaiveDateTime::from_timestamp(t.seconds, t.nanos.try_into().unwrap_or(0)), + Utc, + ), + }; + let dht_header: DhtMessageHeader = message .dht_header .expect("previously checked") @@ -410,6 +426,27 @@ where S: Service DhtInboundMessage::new(MessageTag::new(), dht_header, Arc::clone(&source_peer), message.body); inbound_msg.is_saf_message = true; + let last_saf_received = self + .dht_requester + .get_metadata::>(DhtMetadataKey::LastSafMessageReceived) + .await + .ok() + .flatten() + .unwrap_or(chrono::MIN_DATETIME); + + if stored_at > last_saf_received { + if let Err(err) = self + .dht_requester + .set_metadata(DhtMetadataKey::LastSafMessageReceived, stored_at) + .await + { + warn!( + target: LOG_TARGET, + "Failed to set last SAF message received timestamp: {:?}", err + ); + } + } + Ok(DecryptedDhtMessage::succeeded( decrypted_body, authenticated_pk, @@ -515,6 +552,7 @@ mod test { use super::*; use crate::{ envelope::DhtMessageFlags, + outbound::mock::create_outbound_service_mock, proto::envelope::DhtHeader, store_forward::{message::StoredMessagePriority, StoredMessage}, test_utils::{ @@ -528,7 +566,7 @@ mod test { service_spy, }, }; - use chrono::Utc; + use chrono::{Duration as OldDuration, Utc}; use futures::channel::mpsc; use prost::Message; use std::time::Duration; @@ -536,12 +574,17 @@ mod test { use tari_crypto::tari_utilities::hex; use tari_test_utils::collect_stream; use tari_utilities::hex::Hex; - use tokio::runtime::Handle; + use tokio::{runtime::Handle, task, time::delay_for}; // TODO: unit tests for static functions (check_signature, etc) - fn make_stored_message(node_identity: &NodeIdentity, dht_header: DhtMessageHeader) -> StoredMessage { - let body = b"A".to_vec(); + fn make_stored_message( + message: String, + node_identity: &NodeIdentity, + dht_header: DhtMessageHeader, + stored_at: NaiveDateTime, + ) -> StoredMessage { + let body = message.as_bytes().to_vec(); let body_hash = hex::to_hex(&Challenge::new().chain(body.clone()).finalize()); StoredMessage { id: 1, @@ -554,19 +597,20 @@ mod test { body, is_encrypted: false, priority: StoredMessagePriority::High as i32, - stored_at: Utc::now().naive_utc(), + stored_at, body_hash, } } - #[tokio_macros::test_basic] + #[tokio_macros::test] async fn request_stored_messages() { - let rt_handle = Handle::current(); let spy = service_spy(); let (requester, mock_state) = create_store_and_forward_mock(); let peer_manager = build_peer_manager(); - let (oms_tx, mut oms_rx) = mpsc::channel(1); + let (outbound_requester, outbound_mock) = create_outbound_service_mock(10); + let oms_mock_state = outbound_mock.get_state(); + task::spawn(outbound_mock.run()); let node_identity = make_node_identity(); @@ -606,29 +650,59 @@ mod test { requester.clone(), dht_requester.clone(), peer_manager.clone(), - OutboundMessageRequester::new(oms_tx.clone()), + outbound_requester.clone(), node_identity.clone(), message.clone(), saf_response_signal_sender.clone(), ); - rt_handle.spawn(task.run()); + task::spawn(task.run()); - let (_, body) = unwrap_oms_send_msg!(oms_rx.next().await.unwrap()); - let body = body.to_vec(); + for _ in 0..6 { + if oms_mock_state.call_count() >= 1 { + break; + } + delay_for(Duration::from_secs(5)).await; + } + assert_eq!(oms_mock_state.call_count(), 1); + + let call = oms_mock_state.pop_call().unwrap(); + let body = call.1.to_vec(); let body = EnvelopeBody::decode(body.as_slice()).unwrap(); let msg = body.decode_part::(0).unwrap().unwrap(); assert_eq!(msg.messages().len(), 0); assert!(!spy.is_called()); - assert_eq!(mock_state.call_count(), 1); + // assert_eq!(mock_state.call_count(), 2); let calls = mock_state.take_calls().await; - assert!(calls[0].contains("FetchMessages")); - assert!(calls[0].contains(node_identity.public_key().to_hex().as_str())); - assert!(calls[0].contains(format!("{:?}", since).as_str())); + let fetch_call = calls.iter().find(|c| c.contains("FetchMessages")).unwrap(); + assert!(fetch_call.contains(node_identity.public_key().to_hex().as_str())); + assert!(fetch_call.contains(format!("{:?}", since).as_str())); + let msg1_time = Utc::now() + .checked_sub_signed(OldDuration::from_std(Duration::from_secs(120)).unwrap()) + .unwrap(); + let msg1 = "one".to_string(); mock_state - .add_message(make_stored_message(&node_identity, dht_header)) + .add_message(make_stored_message( + msg1.clone(), + &node_identity, + dht_header.clone(), + msg1_time.naive_utc(), + )) + .await; + + let msg2_time = Utc::now() + .checked_sub_signed(OldDuration::from_std(Duration::from_secs(30)).unwrap()) + .unwrap(); + let msg2 = "two".to_string(); + mock_state + .add_message(make_stored_message( + msg2.clone(), + &node_identity, + dht_header, + msg2_time.naive_utc(), + )) .await; // Now lets test its response where there are messages to return. @@ -638,27 +712,42 @@ mod test { requester, dht_requester, peer_manager, - OutboundMessageRequester::new(oms_tx), + outbound_requester.clone(), node_identity.clone(), message, saf_response_signal_sender, ); - rt_handle.spawn(task.run()); + task::spawn(task.run()); - let (_, body) = unwrap_oms_send_msg!(oms_rx.next().await.unwrap()); - let body = body.to_vec(); + for _ in 0..6 { + if oms_mock_state.call_count() >= 1 { + break; + } + delay_for(Duration::from_secs(5)).await; + } + assert_eq!(oms_mock_state.call_count(), 1); + let call = oms_mock_state.pop_call().unwrap(); + + let body = call.1.to_vec(); let body = EnvelopeBody::decode(body.as_slice()).unwrap(); let msg = body.decode_part::(0).unwrap().unwrap(); + assert_eq!(msg.messages().len(), 1); - assert_eq!(msg.messages()[0].body, b"A"); + assert_eq!(msg.messages()[0].body, "two".as_bytes()); assert!(!spy.is_called()); assert_eq!(mock_state.call_count(), 2); let calls = mock_state.take_calls().await; - assert!(calls[0].contains("FetchMessages")); - assert!(calls[0].contains(node_identity.public_key().to_hex().as_str())); - assert!(calls[0].contains(format!("{:?}", since).as_str())); + + let fetch_call = calls.iter().find(|c| c.contains("FetchMessages")).unwrap(); + assert!(fetch_call.contains(node_identity.public_key().to_hex().as_str())); + assert!(fetch_call.contains(format!("{:?}", since).as_str())); + + let stored_messages = mock_state.get_messages().await; + + assert!(!stored_messages.iter().any(|s| s.body == msg1.as_bytes())); + assert!(stored_messages.iter().any(|s| s.body == msg2.as_bytes())); } #[tokio_macros::test_basic] @@ -689,13 +778,23 @@ mod test { .await .unwrap(); - let msg1 = ProtoStoredMessage::new(0, inbound_msg_a.dht_header.clone(), inbound_msg_a.body); - let msg2 = ProtoStoredMessage::new(0, inbound_msg_b.dht_header, inbound_msg_b.body); + let msg1_time = Utc::now() + .checked_sub_signed(OldDuration::from_std(Duration::from_secs(60)).unwrap()) + .unwrap(); + let msg1 = ProtoStoredMessage::new(0, inbound_msg_a.dht_header.clone(), inbound_msg_a.body, msg1_time); + let msg2_time = Utc::now() + .checked_sub_signed(OldDuration::from_std(Duration::from_secs(30)).unwrap()) + .unwrap(); + let msg2 = ProtoStoredMessage::new(0, inbound_msg_b.dht_header, inbound_msg_b.body, msg2_time); + // Cleartext message let clear_msg = wrap_in_envelope_body!(b"Clear".to_vec()).to_encoded_bytes(); let clear_header = make_dht_inbound_message(&node_identity, clear_msg.clone(), DhtMessageFlags::empty(), false).dht_header; - let msg_clear = ProtoStoredMessage::new(0, clear_header, clear_msg); + let msg_clear_time = Utc::now() + .checked_sub_signed(OldDuration::from_std(Duration::from_secs(120)).unwrap()) + .unwrap(); + let msg_clear = ProtoStoredMessage::new(0, clear_header, clear_msg, msg_clear_time); let mut message = DecryptedDhtMessage::succeeded( wrap_in_envelope_body!(StoredMessagesResponse { messages: vec![msg1.clone(), msg2, msg_clear], @@ -712,15 +811,21 @@ mod test { ); message.dht_header.message_type = DhtMessageType::SafStoredMessages; - let (dht_requester, mock) = create_dht_actor_mock(1); + let (mut dht_requester, mock) = create_dht_actor_mock(1); rt_handle.spawn(mock.run()); let (saf_response_signal_sender, mut saf_response_signal_receiver) = mpsc::channel(20); + assert!(dht_requester + .get_metadata::>(DhtMetadataKey::LastSafMessageReceived) + .await + .unwrap() + .is_none()); + let task = MessageHandlerTask::new( Default::default(), spy.to_service::(), requester, - dht_requester, + dht_requester.clone(), peer_manager, OutboundMessageRequester::new(oms_tx), node_identity, @@ -746,5 +851,13 @@ mod test { timeout = Duration::from_secs(20) ); assert_eq!(signals.len(), 1); + + let last_saf_received = dht_requester + .get_metadata::>(DhtMetadataKey::LastSafMessageReceived) + .await + .unwrap() + .unwrap(); + + assert_eq!(last_saf_received, msg2_time); } } diff --git a/comms/dht/src/store_forward/service.rs b/comms/dht/src/store_forward/service.rs index c96d4311cb..5d06d85d56 100644 --- a/comms/dht/src/store_forward/service.rs +++ b/comms/dht/src/store_forward/service.rs @@ -43,7 +43,7 @@ use futures::{ StreamExt, }; use log::*; -use std::{cmp, convert::TryFrom, sync::Arc, time::Duration}; +use std::{convert::TryFrom, sync::Arc, time::Duration}; use tari_comms::{ connectivity::{ConnectivityEvent, ConnectivityEventRx, ConnectivityRequester}, peer_manager::{NodeId, PeerFeatures}, @@ -76,7 +76,7 @@ impl FetchStoredMessageQuery { } } - pub fn since(&mut self, since: DateTime) -> &mut Self { + pub fn with_messages_since(&mut self, since: DateTime) -> &mut Self { self.since = Some(since); self } @@ -85,6 +85,10 @@ impl FetchStoredMessageQuery { self.response_type = response_type; self } + + pub fn since(&self) -> Option> { + self.since + } } #[derive(Debug)] @@ -92,6 +96,7 @@ pub enum StoreAndForwardRequest { FetchMessages(FetchStoredMessageQuery, oneshot::Sender>>), InsertMessage(NewStoredMessage, oneshot::Sender>), RemoveMessages(Vec), + RemoveMessagesOlderThan(DateTime), SendStoreForwardRequestToPeer(Box), SendStoreForwardRequestNeighbours, } @@ -132,6 +137,14 @@ impl StoreAndForwardRequester { Ok(()) } + pub async fn remove_messages_older_than(&mut self, threshold: DateTime) -> SafResult<()> { + self.sender + .send(StoreAndForwardRequest::RemoveMessagesOlderThan(threshold)) + .await + .map_err(|_| StoreAndForwardError::RequesterChannelClosed)?; + Ok(()) + } + pub async fn request_saf_messages_from_peer(&mut self, node_id: NodeId) -> SafResult<()> { self.sender .send(StoreAndForwardRequest::SendStoreForwardRequestToPeer(Box::new(node_id))) @@ -297,6 +310,12 @@ impl StoreAndForwardService { ); } }, + RemoveMessagesOlderThan(threshold) => { + match self.database.delete_messages_older_than(threshold.naive_utc()).await { + Ok(_) => trace!(target: LOG_TARGET, "Removed messages older than {}", threshold), + Err(err) => error!(target: LOG_TARGET, "RemoveMessage failed because '{:?}'", err), + } + }, } } @@ -382,9 +401,9 @@ impl StoreAndForwardService { async fn get_saf_request(&mut self) -> SafResult { let request = self .dht_requester - .get_metadata(DhtMetadataKey::OfflineTimestamp) + .get_metadata(DhtMetadataKey::LastSafMessageReceived) .await? - .map(|t| StoredMessagesRequest::since(cmp::min(t, since_utc(self.config.saf_minimum_request_period)))) + .map(StoredMessagesRequest::since) .unwrap_or_else(StoredMessagesRequest::new); Ok(request) @@ -490,7 +509,3 @@ fn since(period: Duration) -> NaiveDateTime { .checked_sub_signed(period) .expect("period overflowed when used with checked_sub_signed") } - -fn since_utc(period: Duration) -> DateTime { - DateTime::::from_utc(since(period), Utc) -} diff --git a/comms/dht/src/test_utils/store_and_forward_mock.rs b/comms/dht/src/test_utils/store_and_forward_mock.rs index 6a623a5764..0dd464c43a 100644 --- a/comms/dht/src/test_utils/store_and_forward_mock.rs +++ b/comms/dht/src/test_utils/store_and_forward_mock.rs @@ -83,7 +83,9 @@ impl StoreAndForwardMockState { } pub async fn take_calls(&self) -> Vec { - self.calls.write().await.drain(..).collect() + let calls = self.calls.write().await.drain(..).collect(); + self.call_count.store(0, Ordering::SeqCst); + calls } } @@ -115,9 +117,16 @@ impl StoreAndForwardMock { trace!(target: LOG_TARGET, "StoreAndForwardMock received request {:?}", req); self.state.add_call(&req).await; match req { - FetchMessages(_, reply_tx) => { + FetchMessages(request, reply_tx) => { + let since = request.since().unwrap(); + let msgs = self.state.stored_messages.read().await; - let _ = reply_tx.send(Ok(msgs.clone())); + + let _ = reply_tx.send(Ok(msgs + .clone() + .drain(..) + .filter(|m| m.stored_at >= since.naive_utc()) + .collect())); }, InsertMessage(msg, reply_tx) => { self.state.stored_messages.write().await.push(StoredMessage { @@ -143,6 +152,13 @@ impl StoreAndForwardMock { }, SendStoreForwardRequestToPeer(_) => {}, SendStoreForwardRequestNeighbours => {}, + RemoveMessagesOlderThan(threshold) => { + self.state + .stored_messages + .write() + .await + .retain(|msg| msg.stored_at >= threshold.naive_utc()); + }, } } }