Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(dgw): polish scanner, fix issues #715

Merged
merged 17 commits into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/network-scanner/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,5 @@ netlink-packet-route = "0.19.0"
rtnetlink = "0.14.1"

[dev-dependencies]
tokio = { version = "1.36.0", features = ["rt","macros","rt-multi-thread","tracing"] }
tokio = { version = "1.36.0", features = ["rt", "macros", "rt-multi-thread", "tracing", "signal"] }
tracing-subscriber = "0.3.18"
12 changes: 9 additions & 3 deletions crates/network-scanner/examples/scan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::time::Duration;

use anyhow::Context;
use network_scanner::scanner::{NetworkScanner, NetworkScannerParams};
use tokio::time::timeout;

fn main() -> anyhow::Result<()> {
tracing_subscriber::fmt::SubscriberBuilder::default()
Expand All @@ -25,16 +26,21 @@ fn main() -> anyhow::Result<()> {

mdns_query_timeout: 5 * 1000,

max_wait_time: 120 * 1000,
max_wait_time: 10 * 1000,
};
let rt = tokio::runtime::Runtime::new()?;
rt.block_on(async move {
let scanner = NetworkScanner::new(params).unwrap();
let stream = scanner.start()?;
let stream_clone = stream.clone();
let now = std::time::Instant::now();
while let Ok(Some(res)) = stream_clone
.recv_timeout(Duration::from_secs(120))
tokio::task::spawn(async move {
if tokio::signal::ctrl_c().await.is_ok() {
tracing::info!("Ctrl-C received, stopping network scan");
stream.stop();
}
});
while let Ok(Some(res)) = timeout(Duration::from_secs(120), stream_clone.recv())
.await
.with_context(|| {
tracing::error!("Failed to receive from stream");
Expand Down
1 change: 1 addition & 0 deletions crates/network-scanner/src/ip_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use anyhow::Context;
use network_interface::{Addr, NetworkInterfaceConfig, V4IfAddr};

#[derive(Debug, Clone)]
pub struct IpAddrRange {
lower: IpAddr,
upper: IpAddr,
Expand Down
48 changes: 38 additions & 10 deletions crates/network-scanner/src/mdns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,35 @@ impl MdnsDaemon {
pub fn get_service_daemon(&self) -> mdns_sd::ServiceDaemon {
self.service_daemon.clone()
}

pub fn stop(&self) {
let receiver = match self.service_daemon.shutdown() {
Ok(receiver) => receiver,
Err(e) => {
// if e is try again, we should try again, but only once
let result = if matches!(e, mdns_sd::Error::Again) {
self.service_daemon.shutdown()
} else {
Err(e)
};

let Ok(receiver) = result.inspect_err(|e| {
warn!(error = %e, "Failed to shutdown service daemon");
}) else {
return;
};

receiver
}
};

// Receive the last event (Shutdown), preventing the receiver from being dropped, avoid logging an error from the sender side(the mdns crate)
let _ = receiver.recv_timeout(std::time::Duration::from_millis(100));
irvingoujAtDevolution marked this conversation as resolved.
Show resolved Hide resolved
}
}

const SERVICE_TYPES_INTERESTED: [ServiceType; 11] = [
ServiceType::Ard,
// ARD is a variant of the RFB (VNC) protocol, so it’s not included in this list.
const SERVICE_TYPES_INTERESTED: [ServiceType; 10] = [
ServiceType::Http,
ServiceType::Https,
ServiceType::Ldap,
Expand Down Expand Up @@ -52,21 +77,25 @@ pub fn mdns_query_scan(
service_daemon.clone(),
service_name.clone(),
);
let receiver = service_daemon.browse(service_name.as_ref()).with_context(|| {
let err_msg = format!("failed to browse for service: {}", service_name);
error!(error = err_msg);
err_msg
})?;

let receiver_clone = receiver.clone();
task_manager
.with_timeout(query_duration)
.when_finish(move || {
debug!(service_name = ?service_name_clone, "Stopping browse for service");
if let Err(e) = service_daemon_clone.stop_browse(service_name_clone.as_ref()) {
warn!(error = %e, "Failed to stop browsing for service");
}
// Receive the last event (StopBrowse), preventing the receiver from being dropped,this will satisfy the sender side to avoid loging an error
let _ = receiver_clone.recv_timeout(std::time::Duration::from_millis(10));
})
.spawn(move |_| async move {
debug!(?service_name, "Starting browse for service");
let receiver = service_daemon.browse(service_name.as_ref()).with_context(|| {
let err_msg = format!("failed to browse for service: {}", service_name);
error!(error = err_msg);
err_msg
})?;

while let Ok(service_event) = receiver.recv_async().await {
debug!(?service_event);
Expand Down Expand Up @@ -129,9 +158,8 @@ impl TryFrom<&str> for ServiceType {
"_telnet._tcp" => Ok(ServiceType::Telnet),
"_ldap._tcp" => Ok(ServiceType::Ldap),
"_ldaps._tcp" => Ok(ServiceType::Ldaps),
// https://jonathanmumm.com/tech-it/mdns-bonjour-bible-common-service-strings-for-various-vendors/
// OSX Screen Sharing
"_rfb._tcp" => Ok(ServiceType::Ard),
// ARD is a variant of RFB (VNC) protocol.
"_rfb._tcp" => Ok(ServiceType::Vnc),
"_rdp._tcp" | "_rdp._udp" => Ok(ServiceType::Rdp),
_ => Err(anyhow::anyhow!("unknown protocol: {}", value)),
}
Expand Down
32 changes: 24 additions & 8 deletions crates/network-scanner/src/scanner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ impl NetworkScanner {
}: TaskExecutionContext,
task_manager| async move {
for subnet in subnets {
debug!(broadcasting_to_subnet = ?subnet);
let (runtime, ip_sender) = (runtime.clone(), ip_sender.clone());
task_manager.spawn(move |task_manager: crate::task_utils::TaskManager| async move {
let mut receiver =
Expand All @@ -140,6 +141,7 @@ impl NetworkScanner {
}: TaskExecutionContext,
task_manager| async move {
let ip_ranges: Vec<IpAddrRange> = subnets.iter().map(|subnet| subnet.into()).collect();
debug!(netbios_query_ip_ranges = ?ip_ranges);

for ip_range in ip_ranges {
let (runtime, ip_sender, task_manager) = (runtime.clone(), ip_sender.clone(), task_manager.clone());
Expand All @@ -166,6 +168,7 @@ impl NetworkScanner {
}: TaskExecutionContext,
task_manager| async move {
let ip_ranges: Vec<IpAddrRange> = subnets.iter().map(|subnet| subnet.into()).collect();
debug!(ping_ip_ranges = ?ip_ranges);

let should_ping = move |ip: IpAddr| -> bool { !ip_cache.read().contains_key(&ip) };

Expand Down Expand Up @@ -220,18 +223,29 @@ impl NetworkScanner {
);

let TaskExecutionRunner {
context: TaskExecutionContext { port_receiver, .. },
context:
TaskExecutionContext {
port_receiver,
mdns_daemon,
..
},
task_manager,
} = task_executor;

task_manager.stop_timeout(self.max_wait_time);
let scanner_stream = Arc::new(NetworkScannerStream {
result_receiver: port_receiver,
task_manager,
mdns_daemon,
});

Ok({
Arc::new(NetworkScannerStream {
result_receiver: port_receiver,
task_manager,
})
})
let (scanner_stream_clone, max_wait_time) = (scanner_stream.clone(), self.max_wait_time);

tokio::spawn(async move {
tokio::time::sleep(max_wait_time).await;
scanner_stream_clone.stop();
});

Ok(scanner_stream)
}

pub fn new(
Expand Down Expand Up @@ -289,6 +303,7 @@ pub struct ScanEntry {
pub struct NetworkScannerStream {
result_receiver: Arc<Mutex<ScanEntryReceiver>>,
task_manager: TaskManager,
mdns_daemon: MdnsDaemon,
}

impl NetworkScannerStream {
Expand All @@ -305,6 +320,7 @@ impl NetworkScannerStream {

pub fn stop(self: Arc<Self>) {
self.task_manager.stop();
self.mdns_daemon.stop();
}
}

Expand Down
8 changes: 0 additions & 8 deletions crates/network-scanner/src/task_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,14 +182,6 @@ impl TaskManager {
}
debug!("All tasks stopped");
}

pub(crate) fn stop_timeout(&self, timeout: Duration) {
let self_clone = self.clone();
tokio::spawn(async move {
tokio::time::sleep(timeout).await;
self_clone.stop();
});
}
}

pub(crate) struct TimeoutManager {
Expand Down
25 changes: 23 additions & 2 deletions devolutions-gateway/src/api/net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,13 @@ pub async fn handle_network_scan(
tokio::select! {
result = stream.recv() => {
let Some(entry) = result else {
let _ = websocket
.send(Message::Close(Some(axum::extract::ws::CloseFrame {
code: axum::extract::ws::close_code::NORMAL,
reason: std::borrow::Cow::from("network scan finished successfully"),
})))
.await;

break;
};

Expand All @@ -56,6 +63,15 @@ pub async fn handle_network_scan(

if let Err(error) = websocket.send(Message::Text(response)).await {
warn!(%error, "Failed to send message");

// It is very likely that the websocket is already closed, but send it as a precaution.
let _ = websocket
.send(Message::Close(Some(axum::extract::ws::CloseFrame {
code: axum::extract::ws::close_code::ABNORMAL,
reason: std::borrow::Cow::from("network scan finished prematurely."),
})))
.await;

break;
}
},
Expand All @@ -71,9 +87,14 @@ pub async fn handle_network_scan(
}
}

info!("Network scan finished");

// Stop the network scanner, whatever the code path (error or not).
stream.stop();

// In case the websocket is not closed yet.
// If the logic above is correct, it’s not necessary.
let _ = websocket.close().await;

info!("Network scan finished");
});

Ok(res)
Expand Down
Loading