Skip to content

Commit

Permalink
add refresh for SRV records and send ServiceRemoved for expired SRV r…
Browse files Browse the repository at this point in the history
…ecords

- Added logic to handle TTL = 0 in responses.
- Only apply cache-flush bit if a new record (i.e. rdata changed).
  • Loading branch information
keepsimple1 committed May 6, 2024
1 parent 06e2cf7 commit fb0869a
Show file tree
Hide file tree
Showing 3 changed files with 312 additions and 106 deletions.
87 changes: 82 additions & 5 deletions src/dns_parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,14 @@ pub struct DnsRecord {
impl DnsRecord {
fn new(name: &str, ty: u16, class: u16, ttl: u32) -> Self {
let created = current_time_millis();

// From RFC 6762 section 5.2:
// "... The querier should plan to issue a query at 80% of the record
// lifetime, and then if no answer is received, at 85%, 90%, and 95%."
let refresh = get_expiration_time(created, ttl, 80);

let expires = get_expiration_time(created, ttl, 100);

Self {
entry: DnsEntry::new(name.to_string(), ty, class),
ttl,
Expand Down Expand Up @@ -136,6 +142,36 @@ impl DnsRecord {
self.refresh = get_expiration_time(self.created, self.ttl, 100);
}

/// Returns if this record is due for refresh. If yes, `refresh` time is updated.
pub(crate) fn refresh_maybe(&mut self, now: u64) -> bool {
if self.is_expired(now) || !self.refresh_due(now) {
return false;
}

debug!(
"{} qtype {} is due to refresh",
&self.entry.name, self.entry.ty
);

// From RFC 6762 section 5.2:
// "... The querier should plan to issue a query at 80% of the record
// lifetime, and then if no answer is received, at 85%, 90%, and 95%."
//
// If the answer is received in time, 'refresh' will be reset outside
// this function, back to 80% of the new TTL.
if self.refresh == get_expiration_time(self.created, self.ttl, 80) {
self.refresh = get_expiration_time(self.created, self.ttl, 85);
} else if self.refresh == get_expiration_time(self.created, self.ttl, 85) {
self.refresh = get_expiration_time(self.created, self.ttl, 90);
} else if self.refresh == get_expiration_time(self.created, self.ttl, 90) {
self.refresh = get_expiration_time(self.created, self.ttl, 95);
} else {
self.refresh_no_more();
}

true
}

/// Returns the remaining TTL in seconds
fn get_remaining_ttl(&self, now: u64) -> u32 {
let remaining_millis = get_expiration_time(self.created, self.ttl, 100) - now;
Expand Down Expand Up @@ -190,6 +226,8 @@ pub trait DnsRecordExt: fmt::Debug {
self.get_record().entry.ty
}

/// Resets TTL using `other` record.
/// `self.refresh` and `self.expires` are also reset.
fn reset_ttl(&mut self, other: &dyn DnsRecordExt) {
self.get_record_mut().reset_ttl(other.get_record());
}
Expand Down Expand Up @@ -1084,7 +1122,16 @@ impl DnsIncoming {

let ty = u16_from_be_slice(&slice[..2]);
let class = u16_from_be_slice(&slice[2..4]);
let ttl = u32_from_be_slice(&slice[4..8]);
let mut ttl = u32_from_be_slice(&slice[4..8]);
if ttl == 0 && self.is_response() {
// RFC 6762 section 10.1:
// "...Queriers receiving a Multicast DNS response with a TTL of zero SHOULD
// NOT immediately delete the record from the cache, but instead record
// a TTL of 1 and then delete the record one second later."
// See https://datatracker.ietf.org/doc/html/rfc6762#section-10.1

ttl = 1;
}
let length = u16_from_be_slice(&slice[8..10]) as usize;
self.offset += 10;
let next_offset = self.offset + length;
Expand Down Expand Up @@ -1354,19 +1401,22 @@ const fn u32_from_be_slice(s: &[u8]) -> u32 {
u32::from_be_bytes(u8_array)
}

/// Returns the time in millis at which this record will have expired
/// Returns the UNIX time in millis at which this record will have expired
/// by a certain percentage.
const fn get_expiration_time(created: u64, ttl: u32, percent: u32) -> u64 {
// 'created' is in millis, 'ttl' is in seconds, hence:
// ttl * 1000 * (percent / 100) => ttl * percent * 10
created + (ttl * percent * 10) as u64
}

#[cfg(test)]
mod tests {
use crate::dns_parser::{TYPE_A, TYPE_AAAA};
use crate::dns_parser::get_expiration_time;

use super::{
DnsIncoming, DnsNSec, DnsOutgoing, DnsSrv, CLASS_CACHE_FLUSH, CLASS_IN, FLAGS_QR_QUERY,
FLAGS_QR_RESPONSE, TYPE_PTR,
current_time_millis, DnsIncoming, DnsNSec, DnsOutgoing, DnsRecordExt, DnsSrv,
CLASS_CACHE_FLUSH, CLASS_IN, FLAGS_QR_QUERY, FLAGS_QR_RESPONSE, TYPE_A, TYPE_AAAA,
TYPE_PTR,
};

#[test]
Expand Down Expand Up @@ -1453,4 +1503,31 @@ mod tests {
assert_eq!(absent_types[0], TYPE_A);
assert_eq!(absent_types[1], TYPE_AAAA);
}

#[test]
fn test_refresh_maybe() {
let name = "test_refresh._udp.local.";
let ttl = 2;
let hostname = "instance1.local.";
let mut srv = DnsSrv::new(name, CLASS_IN, ttl, 0, 0, 0, hostname.to_string());

// refresh is not due yet.
let now = current_time_millis();
let refreshed = srv.get_record_mut().refresh_maybe(now);
assert!(!refreshed);

// sleep for 80 percent of TTL in millis to reach "refresh" time.
let sleep_in_mills = (ttl * 80 * 10) as u64;
std::thread::sleep(std::time::Duration::from_millis(sleep_in_mills));

// refresh is due.
let now = current_time_millis();
let refreshed = srv.get_record_mut().refresh_maybe(now);
assert!(refreshed);

// refresh time is updated.
let dns_record = srv.get_record();
let new_refresh = get_expiration_time(dns_record.get_created(), dns_record.ttl, 85);
assert_eq!(new_refresh, dns_record.get_refresh_time());
}
}
Loading

0 comments on commit fb0869a

Please sign in to comment.