From fb0869a032393c614399a2aa1c0603ab44c7774f Mon Sep 17 00:00:00 2001 From: Han Xu Date: Sun, 3 Mar 2024 22:50:51 -0800 Subject: [PATCH] add refresh for SRV records and send ServiceRemoved for expired SRV records - Added logic to handle TTL = 0 in responses. - Only apply cache-flush bit if a new record (i.e. rdata changed). --- src/dns_parser.rs | 87 ++++++++++- src/service_daemon.rs | 325 +++++++++++++++++++++++++++++------------- src/service_info.rs | 6 + 3 files changed, 312 insertions(+), 106 deletions(-) diff --git a/src/dns_parser.rs b/src/dns_parser.rs index df86d0d..9325430 100644 --- a/src/dns_parser.rs +++ b/src/dns_parser.rs @@ -103,8 +103,14 @@ pub struct DnsRecord { impl DnsRecord { fn new(name: &str, ty: u16, class: u16, ttl: u32) -> Self { let created = current_time_millis(); + + // From RFC 6762 section 5.2: + // "... The querier should plan to issue a query at 80% of the record + // lifetime, and then if no answer is received, at 85%, 90%, and 95%." let refresh = get_expiration_time(created, ttl, 80); + let expires = get_expiration_time(created, ttl, 100); + Self { entry: DnsEntry::new(name.to_string(), ty, class), ttl, @@ -136,6 +142,36 @@ impl DnsRecord { self.refresh = get_expiration_time(self.created, self.ttl, 100); } + /// Returns if this record is due for refresh. If yes, `refresh` time is updated. + pub(crate) fn refresh_maybe(&mut self, now: u64) -> bool { + if self.is_expired(now) || !self.refresh_due(now) { + return false; + } + + debug!( + "{} qtype {} is due to refresh", + &self.entry.name, self.entry.ty + ); + + // From RFC 6762 section 5.2: + // "... The querier should plan to issue a query at 80% of the record + // lifetime, and then if no answer is received, at 85%, 90%, and 95%." + // + // If the answer is received in time, 'refresh' will be reset outside + // this function, back to 80% of the new TTL. + if self.refresh == get_expiration_time(self.created, self.ttl, 80) { + self.refresh = get_expiration_time(self.created, self.ttl, 85); + } else if self.refresh == get_expiration_time(self.created, self.ttl, 85) { + self.refresh = get_expiration_time(self.created, self.ttl, 90); + } else if self.refresh == get_expiration_time(self.created, self.ttl, 90) { + self.refresh = get_expiration_time(self.created, self.ttl, 95); + } else { + self.refresh_no_more(); + } + + true + } + /// Returns the remaining TTL in seconds fn get_remaining_ttl(&self, now: u64) -> u32 { let remaining_millis = get_expiration_time(self.created, self.ttl, 100) - now; @@ -190,6 +226,8 @@ pub trait DnsRecordExt: fmt::Debug { self.get_record().entry.ty } + /// Resets TTL using `other` record. + /// `self.refresh` and `self.expires` are also reset. fn reset_ttl(&mut self, other: &dyn DnsRecordExt) { self.get_record_mut().reset_ttl(other.get_record()); } @@ -1084,7 +1122,16 @@ impl DnsIncoming { let ty = u16_from_be_slice(&slice[..2]); let class = u16_from_be_slice(&slice[2..4]); - let ttl = u32_from_be_slice(&slice[4..8]); + let mut ttl = u32_from_be_slice(&slice[4..8]); + if ttl == 0 && self.is_response() { + // RFC 6762 section 10.1: + // "...Queriers receiving a Multicast DNS response with a TTL of zero SHOULD + // NOT immediately delete the record from the cache, but instead record + // a TTL of 1 and then delete the record one second later." + // See https://datatracker.ietf.org/doc/html/rfc6762#section-10.1 + + ttl = 1; + } let length = u16_from_be_slice(&slice[8..10]) as usize; self.offset += 10; let next_offset = self.offset + length; @@ -1354,19 +1401,22 @@ const fn u32_from_be_slice(s: &[u8]) -> u32 { u32::from_be_bytes(u8_array) } -/// Returns the time in millis at which this record will have expired +/// Returns the UNIX time in millis at which this record will have expired /// by a certain percentage. const fn get_expiration_time(created: u64, ttl: u32, percent: u32) -> u64 { + // 'created' is in millis, 'ttl' is in seconds, hence: + // ttl * 1000 * (percent / 100) => ttl * percent * 10 created + (ttl * percent * 10) as u64 } #[cfg(test)] mod tests { - use crate::dns_parser::{TYPE_A, TYPE_AAAA}; + use crate::dns_parser::get_expiration_time; use super::{ - DnsIncoming, DnsNSec, DnsOutgoing, DnsSrv, CLASS_CACHE_FLUSH, CLASS_IN, FLAGS_QR_QUERY, - FLAGS_QR_RESPONSE, TYPE_PTR, + current_time_millis, DnsIncoming, DnsNSec, DnsOutgoing, DnsRecordExt, DnsSrv, + CLASS_CACHE_FLUSH, CLASS_IN, FLAGS_QR_QUERY, FLAGS_QR_RESPONSE, TYPE_A, TYPE_AAAA, + TYPE_PTR, }; #[test] @@ -1453,4 +1503,31 @@ mod tests { assert_eq!(absent_types[0], TYPE_A); assert_eq!(absent_types[1], TYPE_AAAA); } + + #[test] + fn test_refresh_maybe() { + let name = "test_refresh._udp.local."; + let ttl = 2; + let hostname = "instance1.local."; + let mut srv = DnsSrv::new(name, CLASS_IN, ttl, 0, 0, 0, hostname.to_string()); + + // refresh is not due yet. + let now = current_time_millis(); + let refreshed = srv.get_record_mut().refresh_maybe(now); + assert!(!refreshed); + + // sleep for 80 percent of TTL in millis to reach "refresh" time. + let sleep_in_mills = (ttl * 80 * 10) as u64; + std::thread::sleep(std::time::Duration::from_millis(sleep_in_mills)); + + // refresh is due. + let now = current_time_millis(); + let refreshed = srv.get_record_mut().refresh_maybe(now); + assert!(refreshed); + + // refresh time is updated. + let dns_record = srv.get_record(); + let new_refresh = get_expiration_time(dns_record.get_created(), dns_record.ttl, 85); + assert_eq!(new_refresh, dns_record.get_refresh_time()); + } } diff --git a/src/service_daemon.rs b/src/service_daemon.rs index 0a27221..9acbc77 100644 --- a/src/service_daemon.rs +++ b/src/service_daemon.rs @@ -574,37 +574,23 @@ impl ServiceDaemon { // check and evict expired records in our cache let now = current_time_millis(); - let expired_records = zc.cache.evict_expired(now); // Notify service listeners about the expired records. - for expired in &expired_records { - if let Some(dns_ptr) = expired.any().downcast_ref::() { - let ty_domain = dns_ptr.get_name(); - call_service_listener( - &zc.service_queriers, - ty_domain, - ServiceEvent::ServiceRemoved(ty_domain.to_string(), dns_ptr.alias.clone()), - ); - } - } + let expired_serives = zc.cache.evict_expired_services(now); + zc.notify_service_removal(expired_serives); // Notify hostname listeners about the expired records. - expired_records - .iter() - .filter_map(|record| record.any().downcast_ref::()) - .map(|record| record.get_name().to_owned()) - .collect::>() - .iter() - .for_each(|hostname| { - call_hostname_resolution_listener( - &zc.hostname_resolvers, - hostname, - HostnameResolutionEvent::AddressesRemoved( - hostname.to_string(), - zc.cache.get_addresses_for_host(hostname), - ), - ) - }); + let expired_addrs = zc.cache.evict_expired_addr(now); + expired_addrs.iter().for_each(|hostname| { + call_hostname_resolution_listener( + &zc.hostname_resolvers, + hostname, + HostnameResolutionEvent::AddressesRemoved( + hostname.to_string(), + zc.cache.get_addresses_for_host(hostname), + ), + ) + }); // check IP changes. if now > next_ip_check { @@ -1592,7 +1578,7 @@ impl Zeroconf { } }; - debug!("received {} bytes", sz); + debug!("received {} bytes from IP: {}", sz, ip); // If sz is 0, it means sock reached End-of-File. if sz == 0 { @@ -1836,34 +1822,40 @@ impl Zeroconf { let mut changes = Vec::new(); let mut timers = Vec::new(); for record in msg.answers { - if let Some((dns_record, true)) = self.cache.add_or_update(record) { - timers.push(dns_record.get_record().get_expire_time()); - - let ty = dns_record.get_type(); - let name = dns_record.get_name(); - if ty == TYPE_PTR { - if self.service_queriers.contains_key(name) { - timers.push(dns_record.get_record().get_refresh_time()); - } + match self.cache.add_or_update(record) { + Some((dns_record, true)) => { + timers.push(dns_record.get_record().get_expire_time()); + + let ty = dns_record.get_type(); + let name = dns_record.get_name(); + if ty == TYPE_PTR { + if self.service_queriers.contains_key(name) { + timers.push(dns_record.get_record().get_refresh_time()); + } - // send ServiceFound - if let Some(dns_ptr) = dns_record.any().downcast_ref::() { - call_service_listener( - &self.service_queriers, - name, - ServiceEvent::ServiceFound(name.to_string(), dns_ptr.alias.clone()), - ); + // send ServiceFound + if let Some(dns_ptr) = dns_record.any().downcast_ref::() { + call_service_listener( + &self.service_queriers, + name, + ServiceEvent::ServiceFound(name.to_string(), dns_ptr.alias.clone()), + ); + changes.push(InstanceChange { + ty, + name: dns_ptr.alias.clone(), + }); + } + } else { changes.push(InstanceChange { ty, - name: dns_ptr.alias.clone(), + name: name.to_string(), }); } - } else { - changes.push(InstanceChange { - ty, - name: name.to_string(), - }); } + Some((dns_record, false)) => { + timers.push(dns_record.get_record().get_expire_time()); + } + _ => {} } } @@ -2116,6 +2108,24 @@ impl Zeroconf { self.retransmissions.push(ReRun { next_time, command }); self.add_timer(next_time); } + + /// Sends service removal event to listeners for expired service records. + fn notify_service_removal(&self, expired: HashMap>) { + for (ty_domain, sender) in self.service_queriers.iter() { + if let Some(instances) = expired.get(ty_domain) { + for instance_name in instances { + let event = ServiceEvent::ServiceRemoved( + ty_domain.to_string(), + instance_name.to_string(), + ); + match sender.send(event) { + Ok(()) => debug!("Sent ServiceRemoved to listener successfully"), + Err(e) => error!("Failed to send event: {}", e), + } + } + } + } + } } /// All possible events sent to the client from the daemon @@ -2330,21 +2340,6 @@ impl DnsCache { _ => return None, }; - if incoming.get_cache_flush() { - // Mark all existing records of this type as expired, and prepend the new record. - let now = current_time_millis(); - record_vec.iter_mut().for_each(|r| { - // When cache flush is asked, we set expire date to 1 second in the future - // if created more than 1 second ago - // Ref: RFC 6762 Section 10.2 - if incoming.get_class() == r.get_class() && now > r.get_created() + 1000 { - r.set_expire(now + 1000); - } - }); - record_vec.insert(0, incoming); - return Some((record_vec.first().unwrap(), true)); - } - // update TTL for existing record or create a new record. let (idx, updated) = match record_vec .iter_mut() @@ -2356,6 +2351,20 @@ impl DnsCache { (i, false) } None => { + // Only apply CACHE_FLUSH bit if the record data has changed, i.e. a new record. + if incoming.get_cache_flush() { + // Mark all existing records of this type as expired, and prepend the new record. + let now = current_time_millis(); + + record_vec.iter_mut().for_each(|r| { + // When cache flush is asked, we set expire date to 1 second in the future + // only if created more than 1 second ago. + // Ref: RFC 6762 Section 10.2 + if incoming.get_class() == r.get_class() && now > r.get_created() + 1000 { + r.set_expire(now + 1000); + } + }); + } record_vec.insert(0, incoming); // A new record. (0, true) } @@ -2386,29 +2395,77 @@ impl DnsCache { found } - /// Iterate all records and remove ones that expired, allowing - /// a function `f` to react with the expired ones. - fn evict_expired(&mut self, now: u64) -> Vec { - self.ptr - .values_mut() - .chain(self.srv.values_mut()) - .chain(self.txt.values_mut()) - .chain(self.addr.values_mut()) - .flat_map(|record_boxes| { - let mut removed = Vec::new(); - - // NOTE: replacement for `extract_if`: https://github.com/rust-lang/rust/issues/43244 - let mut i = 0; - while i < record_boxes.len() { - if record_boxes[i].get_record().is_expired(now) { - removed.push(record_boxes.remove(i)); - } else { - i += 1; - } + /// Iterates all ADDR records and remove ones that expired. + /// Returns the record names expired and removed from the cache. + fn evict_expired_addr(&mut self, now: u64) -> Vec { + let mut removed = Vec::new(); + + for records in self.addr.values_mut() { + records.retain(|addr| { + let expired = addr.get_record().is_expired(now); + if expired { + removed.push(addr.get_name().to_string()); } - removed + !expired }) - .collect() + } + + removed + } + + /// Evicts expired PTR and SRV, TXT records for each ty_domain in the cache, and + /// returns the set of expired instance names for each ty_domain. + /// + /// An instance in the returned set indicates its PTR and/or SRV record has expired. + fn evict_expired_services(&mut self, now: u64) -> HashMap> { + let mut expired_instances = HashMap::new(); + + // Check all ty_domain in the cache by following all PTR records, regardless + // if the ty_domain is actively queried or not. + for (ty_domain, ptr_records) in self.ptr.iter_mut() { + for ptr in ptr_records.iter() { + if let Some(dns_ptr) = ptr.any().downcast_ref::() { + let instance_name = &dns_ptr.alias; + + // evict expired SRV records of this instance + if let Some(srv_records) = self.srv.get_mut(instance_name) { + srv_records.retain(|srv| { + let expired = srv.get_record().is_expired(now); + if expired { + debug!("expired SRV: {}: {}", ty_domain, srv.get_name()); + expired_instances + .entry(ty_domain.to_string()) + .or_insert(HashSet::new()) + .insert(srv.get_name().to_string()); + } + !expired + }); + } + + // evict expired TXT records of this instance + if let Some(txt_records) = self.txt.get_mut(instance_name) { + txt_records.retain(|txt| !txt.get_record().is_expired(now)) + } + } + } + + // evict expired PTR records + ptr_records.retain(|x| { + let expired = x.get_record().is_expired(now); + if expired { + if let Some(dns_ptr) = x.any().downcast_ref::() { + debug!("expired PTR: {}: {}", ty_domain, dns_ptr.alias); + expired_instances + .entry(ty_domain.to_string()) + .or_insert(HashSet::new()) + .insert(dns_ptr.alias.clone()); + } + } + !expired + }); + } + + expired_instances } /// Returns the set of instance names that are due for refresh @@ -2419,23 +2476,34 @@ impl DnsCache { fn refresh_due_services(&mut self, ty_domain: &str) -> HashSet { let now = current_time_millis(); - self.ptr - .get_mut(ty_domain) - .into_iter() - .flatten() - .filter_map(|record| { - let rec = record.get_record_mut(); - if rec.is_expired(now) || !rec.refresh_due(now) { - return None; + let mut refresh_due = HashSet::new(); + let mut ptr_not_due = vec![]; + + // find PTR records that are due for refresh. + for record in self.ptr.get_mut(ty_domain).into_iter().flatten() { + let is_due = record.get_record_mut().refresh_maybe(now); + + if let Some(ptr) = record.any().downcast_ref::() { + let instance = ptr.alias.clone(); + if is_due { + refresh_due.insert(instance); + } else { + ptr_not_due.push(instance); } - rec.refresh_no_more(); + } + } - record - .any() - .downcast_ref::() - .map(|dns_ptr| dns_ptr.alias.clone()) - }) - .collect() + // Check SRV records. + for instance in ptr_not_due { + for record in self.srv.get_mut(&instance).into_iter().flatten() { + if record.get_record_mut().refresh_maybe(now) { + refresh_due.insert(instance); + break; + } + } + } + + refresh_due } /// Returns the set of A/AAAA records that are due for refresh for a `hostname`. @@ -2649,6 +2717,7 @@ mod tests { service_info::IntoTxtProperties, }; use std::{collections::HashMap, net::SocketAddr, net::SocketAddrV4, time::Duration}; + use test_log::test; #[test] fn test_socketaddr_print() { @@ -2939,4 +3008,58 @@ mod tests { assert!(resolved); d.shutdown().unwrap(); } + + #[test] + fn test_expired_srv() { + // construct service info + let service_type = "_expired-srv._udp.local."; + let instance = "test_instance"; + let host_name = "expired_srv_host.local."; + let mut my_service = ServiceInfo::new(service_type, instance, host_name, "", 5023, None) + .unwrap() + .enable_addr_auto(); + // let fullname = my_service.get_fullname().to_string(); + + // set SRV to expire soon. + let new_ttl = 2; // for testing only. + my_service._set_host_ttl(new_ttl); + + // register my service + let mdns_server = ServiceDaemon::new().expect("Failed to create mdns server"); + let result = mdns_server.register(my_service); + assert!(result.is_ok()); + + let mdns_client = ServiceDaemon::new().expect("Failed to create mdns client"); + let browse_chan = mdns_client.browse(service_type).unwrap(); + let timeout = Duration::from_secs(1); + let mut resolved = false; + + while let Ok(event) = browse_chan.recv_timeout(timeout) { + match event { + ServiceEvent::ServiceResolved(info) => { + resolved = true; + println!("Resolved a service of {}", &info.get_fullname()); + break; + } + _ => {} + } + } + + assert!(resolved); + + // Exit the server so that no more responses. + mdns_server.shutdown().unwrap(); + + // SRV record in the client cache will expire. + let expire_timeout = Duration::from_secs(new_ttl as u64); + while let Ok(event) = browse_chan.recv_timeout(expire_timeout) { + match event { + ServiceEvent::ServiceRemoved(service_type, full_name) => { + println!("Service removed: {}: {}", &service_type, &full_name); + break; + } + _ => {} + } + } + } } diff --git a/src/service_info.rs b/src/service_info.rs index c200a6e..3025fbc 100644 --- a/src/service_info.rs +++ b/src/service_info.rs @@ -302,6 +302,12 @@ impl ServiceInfo { pub(crate) fn set_subtype(&mut self, subtype: String) { self.sub_domain = Some(subtype); } + + /// host_ttl is for SRV and address records + /// currently only used for testing. + pub(crate) fn _set_host_ttl(&mut self, ttl: u32) { + self.host_ttl = ttl; + } } /// Removes potentially duplicated ".local." at the end of "hostname".