Skip to content

Commit

Permalink
fix cache refresh and add metrics (#7)
Browse files Browse the repository at this point in the history
* Fixed a bug in cache refresh causing outgoing queries surge.
* Added `get_metrics` API for debugging and monitoring.
* Fixed a corner case bug in retransmissions.
  • Loading branch information
keepsimple1 authored Jan 11, 2022
1 parent afb5188 commit c5a759e
Show file tree
Hide file tree
Showing 2 changed files with 187 additions and 48 deletions.
224 changes: 176 additions & 48 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ use nix::{
use std::{
any::Any,
cmp,
collections::{BTreeMap, HashMap, HashSet},
collections::{HashMap, HashSet},
convert::TryInto,
fmt,
os::unix::io::RawFd,
Expand Down Expand Up @@ -227,6 +227,37 @@ pub enum UnregisterStatus {
NotFound,
}

/// Different counters included in the metrics.
/// Currently all counters are for outgoing packets.
#[derive(Hash, Eq, PartialEq, Clone)]
enum Counter {
Register,
RegisterResend,
Unregister,
UnregisterResend,
Browse,
Respond,
CacheRefreshQuery,
}

impl fmt::Display for Counter {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Counter::Register => write!(f, "register"),
Counter::RegisterResend => write!(f, "register-resend"),
Counter::Unregister => write!(f, "unregister"),
Counter::UnregisterResend => write!(f, "unregister-resend"),
Counter::Browse => write!(f, "browse"),
Counter::Respond => write!(f, "respond"),
Counter::CacheRefreshQuery => write!(f, "cache-refresh"),
}
}
}

/// The metrics is a HashMap of (name_key, i64_value).
/// The main purpose is to help monitoring the mDNS packet traffic.
pub type Metrics = HashMap<String, i64>;

/// A daemon thread for mDNS
///
/// This struct provides a handle and an API to the daemon. It is cloneable.
Expand Down Expand Up @@ -325,6 +356,21 @@ impl ServiceDaemon {
})
}

/// Returns a channel receiver for the metrics, e.g. input/output counters.
///
/// The metrics returned is a snapshot. Hence the caller should call
/// this method repeatedly if they want to monitor the metrics continuously.
pub fn get_metrics(&self) -> Result<Receiver<Metrics>> {
let (resp_s, resp_r) = bounded(1);
self.sender
.try_send(Command::GetMetrics(resp_s))
.map_err(|e| match e {
TrySendError::Full(_) => Error::Again,
e => e_fmt!("crossbeam::channel::try_send failed: {}", e),
})?;
Ok(resp_r)
}

/// The main event loop of the daemon thread
///
/// In each round, it will:
Expand Down Expand Up @@ -372,40 +418,27 @@ impl ServiceDaemon {
_ => {}
}

// check for repeated commands
// check for repeated commands and run them if their time is up.
let now = current_time_millis();
let keys: Vec<_> = zc.retransmissions.keys().cloned().collect();
for instant in keys.iter() {
if now >= *instant {
debug!("Execute command from planned actions");
match zc.retransmissions.remove_entry(instant) {
Some((_, cmd)) => Self::exec_command(&mut zc, cmd, true),
None => error!("missing command in planned actions"),
}
let mut i = 0;
while i < zc.retransmissions.len() {
if now >= zc.retransmissions[i].next_time {
let rerun = zc.retransmissions.remove(i);
Self::exec_command(&mut zc, rerun.command, true);
} else {
i += 1;
}
}

// check for refresh due records with active queriers
// Refresh cache records with active queriers
let mut query_count = 0;
for (ty_domain, _sender) in zc.queriers.iter() {
if let Some(instances) = zc.cache.map.get(ty_domain) {
for instance_ptr in instances.iter() {
if let Some(dns_ptr) = instance_ptr.any().downcast_ref::<DnsPointer>() {
// get the records of a particular service instance
if let Some(records) = zc.cache.map.get(&dns_ptr.alias) {
let now = current_time_millis();

for record in records.iter() {
let rec = record.get_record();
if !rec.is_expired(now) && rec.refresh_due(now) {
zc.send_query(&dns_ptr.alias, TYPE_ANY);
break; // for one instance, only query once
}
}
}
}
}
for instance in zc.cache.refresh_due(ty_domain).iter() {
zc.send_query(instance, TYPE_ANY);
query_count += 1;
}
}
zc.increase_counter(Counter::CacheRefreshQuery, query_count);

// check and evict expired records in our cache
let now = current_time_millis();
Expand All @@ -424,6 +457,8 @@ impl ServiceDaemon {
}

/// The entry point that executes all commands received by the daemon.
///
/// `repeating`: whether this is a retransmission.
fn exec_command(zc: &mut Zeroconf, command: Command, repeating: bool) {
match command {
Command::Browse(ty, next_delay, listener) => {
Expand All @@ -434,27 +469,34 @@ impl ServiceDaemon {
if !repeating {
zc.add_querier(ty.clone(), listener.clone());
// if we already have the records in our cache, just send them
zc.find_and_send_instances(&ty, listener.clone());
zc.query_cache(&ty, listener.clone());
}

zc.send_query(&ty, TYPE_PTR);
zc.increase_counter(Counter::Browse, 1);

let next_time = current_time_millis() + (next_delay * 1000) as u64;
let max_delay = 60 * 60;
let delay = cmp::min(next_delay * 2, max_delay);
zc.retransmissions
.insert(next_time, Command::Browse(ty, delay, listener));
zc.retransmissions.push(ReRun {
next_time,
command: Command::Browse(ty, delay, listener),
});
}

Command::Register(service_info) => {
debug!("register service {:?}", &service_info);
zc.register_service(service_info);
zc.increase_counter(Counter::Register, 1);
}

Command::Announce(fullname) => {
Command::RegisterResend(fullname) => {
debug!("announce service: {}", &fullname);
match zc.my_services.get(&fullname) {
Some(info) => zc.broadcast_service(info),
Some(info) => {
zc.broadcast_service(info);
zc.increase_counter(Counter::RegisterResend, 1);
}
None => debug!("announce: cannot find such service {}", &fullname),
}
}
Expand All @@ -468,11 +510,14 @@ impl ServiceDaemon {
}
Some((_k, info)) => {
let packet = zc.unregister_service(&info);
zc.increase_counter(Counter::Unregister, 1);
// repeat for one time just in case some peers miss the message
if !repeating && !packet.is_empty() {
let next_time = current_time_millis() + 120;
zc.retransmissions
.insert(next_time, Command::SendPacket(packet));
zc.retransmissions.push(ReRun {
next_time,
command: Command::UnregisterResend(packet),
});
}
UnregisterStatus::OK
}
Expand All @@ -482,9 +527,10 @@ impl ServiceDaemon {
}
}

Command::SendPacket(packet) => {
Command::UnregisterResend(packet) => {
debug!("Send a packet length of {}", packet.len());
zc.send_packet(&packet[..], &zc.broadcast_addr);
zc.increase_counter(Counter::UnregisterResend, 1);
}

Command::StopBrowse(ty_domain) => match zc.queriers.remove_entry(&ty_domain) {
Expand All @@ -495,6 +541,11 @@ impl ServiceDaemon {
},
},

Command::GetMetrics(resp_s) => match resp_s.send(zc.counters.clone()) {
Ok(()) => debug!("Sent metrics to the client"),
Err(e) => error!("Failed to send metrics: {}", e),
},

_ => {
error!("unexpected command: {:?}", &command);
}
Expand Down Expand Up @@ -541,6 +592,11 @@ fn current_time_millis() -> u64 {

type DnsRecordBox = Box<dyn DnsRecordExt + Send>;

struct ReRun {
next_time: u64,
command: Command,
}

/// A struct holding the state. It was inspired by `zeroconf` package in Python.
struct Zeroconf {
/// One socket to receive all mDNS packets incoming, regardless interface.
Expand All @@ -567,8 +623,10 @@ struct Zeroconf {
/// Active queriers interested instances
instances_to_resolve: HashMap<String, ServiceInfo>,

/// All repeating transmissions sorted by their "next_time"
retransmissions: BTreeMap<u64, Command>, // <next_time, command>
/// All repeating transmissions.
retransmissions: Vec<ReRun>,

counters: Metrics,
}

impl Zeroconf {
Expand Down Expand Up @@ -598,7 +656,8 @@ impl Zeroconf {
cache: DnsCache::new(),
queriers: HashMap::new(),
instances_to_resolve: HashMap::new(),
retransmissions: BTreeMap::new(),
retransmissions: Vec::new(),
counters: HashMap::new(),
})
}

Expand All @@ -622,12 +681,15 @@ impl Zeroconf {
// ..The Multicast DNS responder MUST send at least two unsolicited
// responses, one second apart.
let next_time = current_time_millis() + 1000;
self.retransmissions
.insert(next_time, Command::Announce(info.fullname.to_lowercase()));

// The key has to be lower case letter as DNS record name is case insensitive.
// The info will have the original name.
self.my_services.insert(info.fullname.to_lowercase(), info);
let service_fullname = info.fullname.to_lowercase();
self.retransmissions.push(ReRun {
next_time,
command: Command::RegisterResend(service_fullname.clone()),
});
self.my_services.insert(service_fullname, info);
}

/// Send an unsolicited response for owned service
Expand Down Expand Up @@ -736,6 +798,9 @@ impl Zeroconf {
self.send(&out, &self.broadcast_addr)
}

/// Binds a channel `listener` to querying mDNS domain type `ty`.
///
/// If there is already a `listener`, it will be updated, i.e. overwritten.
fn add_querier(&mut self, ty: String, listener: Sender<ServiceEvent>) {
self.queriers.insert(ty, listener);
}
Expand Down Expand Up @@ -821,7 +886,9 @@ impl Zeroconf {
true
}

fn find_and_send_instances(&mut self, ty_domain: &str, sender: Sender<ServiceEvent>) {
/// Checks if `ty_domain` has records in the cache. If yes, sends the
/// cached records via `sender`.
fn query_cache(&mut self, ty_domain: &str, sender: Sender<ServiceEvent>) {
if let Some(records) = self.cache.get_records_by_name(ty_domain) {
for record in records.iter() {
if let Some(ptr) = record.any().downcast_ref::<DnsPointer>() {
Expand Down Expand Up @@ -1004,7 +1071,7 @@ impl Zeroconf {
}
}

fn handle_query(&self, msg: DnsIncoming, addr: &SockAddr) {
fn handle_query(&mut self, msg: DnsIncoming, addr: &SockAddr) {
debug!("handle_query from {}", &addr);
let mut out = DnsOutgoing::new(FLAGS_QR_RESPONSE | FLAGS_AA);

Expand Down Expand Up @@ -1133,6 +1200,19 @@ impl Zeroconf {
if !out.answers.is_empty() {
out.id = msg.id;
self.send(&out, &self.broadcast_addr);

self.increase_counter(Counter::Respond, 1);
}
}

/// Increases the value of `counter` by `count`.
fn increase_counter(&mut self, counter: Counter, count: i64) {
let key = counter.to_string();
match self.counters.get_mut(&key) {
Some(v) => *v += count,
None => {
self.counters.insert(key, count);
}
}
}
}
Expand Down Expand Up @@ -1164,15 +1244,18 @@ enum Command {
/// Unregister a service
Unregister(String, Sender<UnregisterStatus>), // (fullname)

/// Announce a service to local network
Announce(String), // (fullname), only used for retransmission
/// Announce again a service to local network
RegisterResend(String), // (fullname)

/// Send a multicast packet out
SendPacket(Vec<u8>), // (packet content), only for retransmission
/// Resend unregister packet.
UnregisterResend(Vec<u8>), // (packet content)

/// Stop browsing a service type
StopBrowse(String), // (ty_domain)

/// Read the current values of the counters
GetMetrics(Sender<Metrics>),

Exit,
}

Expand Down Expand Up @@ -1248,6 +1331,45 @@ impl DnsCache {
});
}
}

/// Returns the list of full name of the instances for a `ty_domain`.
fn instance_names(&self, ty_domain: &str) -> Vec<String> {
let mut result = Vec::new();
if let Some(instances) = self.map.get(ty_domain) {
for instance_ptr in instances.iter() {
if let Some(dns_ptr) = instance_ptr.any().downcast_ref::<DnsPointer>() {
result.push(dns_ptr.alias.clone());
}
}
}
result
}

/// Returns the list of instance names that are due for refresh
/// for a `ty_domain`.
///
/// For these instances, their refresh time will be updated so that
/// they will not refresh again.
fn refresh_due(&mut self, ty_domain: &str) -> Vec<String> {
let now = current_time_millis();
let mut result = Vec::new();

for instance in self.instance_names(ty_domain).iter() {
if let Some(records) = self.map.get_mut(instance) {
for record in records.iter_mut() {
let rec = record.get_record_mut();
if !rec.is_expired(now) && rec.refresh_due(now) {
result.push(instance.clone());

// Only refresh a record once, until it expires and resets.
rec.refresh_no_more();
break; // for one instance, only query once
}
}
}
}
result
}
}

/// Complete info about a Service Instance.
Expand Down Expand Up @@ -1748,6 +1870,12 @@ impl DnsRecord {
now >= self.refresh
}

/// Updates the refresh time to be the same as the expire time so that
/// there is no more refresh for this record.
fn refresh_no_more(&mut self) {
self.refresh = get_expiration_time(self.created, self.ttl, 100);
}

/// Returns the remaining TTL in seconds
fn get_remaining_ttl(&self, now: u64) -> u32 {
let remaining_millis = get_expiration_time(self.created, self.ttl, 100) - now;
Expand Down
Loading

0 comments on commit c5a759e

Please sign in to comment.