From 0d9112f04900688e731d2de7e18811402ad95322 Mon Sep 17 00:00:00 2001 From: Han Xu Date: Wed, 26 Jun 2024 10:13:11 -0700 Subject: [PATCH] Add sanity check for service type domain suffix in browse --- src/service_daemon.rs | 39 ++++++++++++++++++++++++++++++--------- tests/mdns_test.rs | 8 ++++++++ 2 files changed, 38 insertions(+), 9 deletions(-) diff --git a/src/service_daemon.rs b/src/service_daemon.rs index 8a9ef61..8ed4c8c 100644 --- a/src/service_daemon.rs +++ b/src/service_daemon.rs @@ -223,6 +223,8 @@ impl ServiceDaemon { /// Starts browsing for a specific service type. /// + /// `service_type` must end with a valid mDNS domain: '._tcp.local.' or '._udp.local.' + /// /// Returns a channel `Receiver` to receive events about the service. The caller /// can call `.recv_async().await` on this receiver to handle events in an /// async environment or call `.recv()` in a sync environment. @@ -230,6 +232,8 @@ impl ServiceDaemon { /// When a new instance is found, the daemon automatically tries to resolve, i.e. /// finding more details, i.e. SRV records and TXT records. pub fn browse(&self, service_type: &str) -> Result> { + check_domain_suffix(service_type)?; + let (resp_s, resp_r) = bounded(10); self.send_cmd(Command::Browse(service_type.to_string(), 1, resp_s))?; Ok(resp_r) @@ -2374,6 +2378,18 @@ fn check_service_name_length(ty_domain: &str, limit: u8) -> Result<()> { Ok(()) } +/// Checks if `name` ends with a valid domain: '._tcp.local.' or '._udp.local.' +fn check_domain_suffix(name: &str) -> Result<()> { + if !(name.ends_with("._tcp.local.") || name.ends_with("._udp.local.")) { + return Err(e_fmt!( + "mDNS service {} must end with '._tcp.local.' or '._udp.local.'", + name + )); + } + + Ok(()) +} + /// Validate the service name in a fully qualified name. /// /// A Full Name = .. @@ -2382,12 +2398,7 @@ fn check_service_name_length(ty_domain: &str, limit: u8) -> Result<()> { /// Note: this function does not check for the length of the service name. /// Instead `register_service` method will check the length. fn check_service_name(fullname: &str) -> Result<()> { - if !(fullname.ends_with("._tcp.local.") || fullname.ends_with("._udp.local.")) { - return Err(e_fmt!( - "Service {} must end with '._tcp.local.' or '._udp.local.'", - fullname - )); - } + check_domain_suffix(fullname)?; let remaining: Vec<&str> = fullname[..fullname.len() - DOMAIN_LEN].split('.').collect(); let name = remaining.last().ok_or_else(|| e_fmt!("No service name"))?; @@ -2529,9 +2540,9 @@ fn valid_instance_name(name: &str) -> bool { #[cfg(test)] mod tests { use super::{ - broadcast_dns_on_intf, check_service_name_length, my_ip_interfaces, new_socket_bind, - valid_instance_name, HostnameResolutionEvent, IntfSock, ServiceDaemon, ServiceEvent, - ServiceInfo, GROUP_ADDR_V4, MDNS_PORT, + broadcast_dns_on_intf, check_domain_suffix, check_service_name_length, my_ip_interfaces, + new_socket_bind, valid_instance_name, HostnameResolutionEvent, IntfSock, ServiceDaemon, + ServiceEvent, ServiceInfo, GROUP_ADDR_V4, MDNS_PORT, }; use crate::{ dns_parser::{DnsOutgoing, DnsPointer, CLASS_IN, FLAGS_AA, FLAGS_QR_RESPONSE, TYPE_PTR}, @@ -2591,6 +2602,16 @@ mod tests { } } + #[test] + fn test_check_domain_suffix() { + assert!(check_domain_suffix("_missing_dot._tcp.local").is_err()); + assert!(check_domain_suffix("_missing_bar.tcp.local.").is_err()); + assert!(check_domain_suffix("_mis_spell._tpp.local.").is_err()); + assert!(check_domain_suffix("_mis_spell._upp.local.").is_err()); + assert!(check_domain_suffix("_has_dot._tcp.local.").is_ok()); + assert!(check_domain_suffix("_goodname._udp.local.").is_ok()); + } + #[test] fn service_with_temporarily_invalidated_ptr() { // Create a daemon diff --git a/tests/mdns_test.rs b/tests/mdns_test.rs index 0409644..4a3aae2 100644 --- a/tests/mdns_test.rs +++ b/tests/mdns_test.rs @@ -1411,3 +1411,11 @@ fn test_known_answer_suppression() { println!("metrics: {:?}", &metrics); assert!(metrics["known-answer-suppression"] > 0); } + +#[test] +fn test_domain_suffix_in_browse() { + let mdns_client = ServiceDaemon::new().expect("failed to create mDNS client"); + assert!(mdns_client.browse("_service-name._tcp.local").is_err()); + assert!(mdns_client.browse("_service-name._tcp.local.").is_ok()); + mdns_client.shutdown().unwrap(); +}