Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
folkertdev committed Nov 13, 2023
1 parent 2048648 commit d248da0
Show file tree
Hide file tree
Showing 2 changed files with 194 additions and 15 deletions.
1 change: 1 addition & 0 deletions ntp-proto/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ rust-version.workspace = true
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[features]
default = ["nts-pool"]
__internal-fuzz = ["arbitrary", "__internal-api"]
__internal-test = ["__internal-api"]
__internal-api = []
Expand Down
208 changes: 193 additions & 15 deletions ntp-proto/src/nts_record.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,34 @@ impl NtsRecord {
]
}

pub fn client_key_exchange_records_fixed(
c2s: Vec<u8>,
s2c: Vec<u8>,
) -> [NtsRecord; if cfg!(feature = "ntpv5") { 5 } else { 4 }] {
[
#[cfg(feature = "ntpv5")]
NtsRecord::DraftId {
data: crate::packet::v5::DRAFT_VERSION.as_bytes().into(),
},
NtsRecord::NextProtocol {
protocol_ids: vec![
#[cfg(feature = "ntpv5")]
0x8001,
0,
],
},
NtsRecord::AeadAlgorithm {
critical: false,
algorithm_ids: AeadAlgorithm::IN_ORDER_OF_PREFERENCE
.iter()
.map(|algorithm| *algorithm as u16)
.collect(),
},
NtsRecord::FixedKeyRequest { c2s, s2c },
NtsRecord::EndOfMessage,
]
}

fn server_key_exchange_records(
protocol: ProtocolId,
algorithm: AeadAlgorithm,
Expand Down Expand Up @@ -575,6 +603,8 @@ pub enum KeyExchangeError {
NoValidProtocol,
#[error("No encryption algorithm supported by both us and server")]
NoValidAlgorithm,
#[error("The length of a fixed key does not match the algorithm used")]
InvalidFixedKeyLength,
#[error("Missing cookies")]
NoCookies,
#[error("{0}")]
Expand Down Expand Up @@ -1068,10 +1098,32 @@ struct KeyExchangeServerDecoder {

#[derive(Debug, PartialEq, Eq)]
struct ServerKeyExchangeData {
algorithm: AeadAlgorithm,
key_method: KeyMethod,
protocol: ProtocolId,
}

#[derive(Debug, PartialEq, Eq)]
enum KeyMethod {
/// Perform key extraction to acquire the c2s and s2c keys
KeyExtraction { algorithm: AeadAlgorithm },
/// Use these fixed keys
Fixed {
algorithm: AeadAlgorithm,
c2s: Vec<u8>,
s2c: Vec<u8>,
},
}

impl KeyMethod {
#[cfg(test)]
fn algorithm(&self) -> AeadAlgorithm {
match self {
KeyMethod::KeyExtraction { algorithm } => *algorithm,
KeyMethod::Fixed { algorithm, .. } => *algorithm,
}
}
}

impl KeyExchangeServerDecoder {
pub fn step_with_slice(
mut self,
Expand Down Expand Up @@ -1102,8 +1154,27 @@ impl KeyExchangeServerDecoder {

match record {
EndOfMessage => {
let key_method = {
#[cfg(feature = "nts-pool")]
let fixed_key_request = state.fixed_key_request;

#[cfg(not(feature = "nts-pool"))]
let fixed_key_request = None;

let algorithm = state.algorithm;

match fixed_key_request {
None => KeyMethod::KeyExtraction { algorithm },
Some((c2s, s2c)) => KeyMethod::Fixed {
algorithm,
c2s,
s2c,
},
}
};

let result = ServerKeyExchangeData {
algorithm: state.algorithm,
key_method,
protocol: state.protocol,
};

Expand Down Expand Up @@ -1308,18 +1379,81 @@ impl KeyExchangeServer {
}
ControlFlow::Break(Ok(result)) => {
self.decoder = None;
let algorithm = result.algorithm;
let key_method = result.key_method;
let protocol = result.protocol;

tracing::debug!(?algorithm, "selected AEAD algorithm");
let algorithm = match key_method {
KeyMethod::KeyExtraction { algorithm } => {
tracing::debug!(
?algorithm,
"selected AEAD algorithm for key extraction"
);
algorithm
}
KeyMethod::Fixed { algorithm, .. } => {
tracing::debug!(
?algorithm,
"using fixed keys with AEAD algorithm"
);
algorithm
}
};

let keys = match algorithm
.extract_nts_keys(protocol, &self.tls_connection)
{
Ok(keys) => keys,
Err(e) => {
return ControlFlow::Break(Err(KeyExchangeError::Tls(e)))
let keys = match key_method {
KeyMethod::KeyExtraction { algorithm } => {
match algorithm
.extract_nts_keys(protocol, &self.tls_connection)
{
Ok(keys) => keys,
Err(e) => {
return ControlFlow::Break(Err(
KeyExchangeError::Tls(e),
))
}
}
}
KeyMethod::Fixed {
algorithm,
c2s,
s2c,
} => match algorithm {
AeadAlgorithm::AeadAesSivCmac256 => {
const KEY_WIDTH: usize =
std::mem::size_of::<aead::Key<Aes128SivAead>>();

if c2s.len() != KEY_WIDTH || s2c.len() != KEY_WIDTH {
return ControlFlow::Break(Err(
KeyExchangeError::InvalidFixedKeyLength,
));
}

let c2s = *aead::Key::<Aes128SivAead>::from_slice(&c2s);
let s2c = *aead::Key::<Aes128SivAead>::from_slice(&s2c);

let c2s = Box::new(AesSivCmac256::new(c2s));
let s2c = Box::new(AesSivCmac256::new(s2c));

NtsKeys { c2s, s2c }
}
AeadAlgorithm::AeadAesSivCmac512 => {
const KEY_WIDTH: usize =
std::mem::size_of::<aead::Key<Aes256SivAead>>();

if c2s.len() != KEY_WIDTH || s2c.len() != KEY_WIDTH {
return ControlFlow::Break(Err(
KeyExchangeError::InvalidFixedKeyLength,
));
}

let c2s = *aead::Key::<Aes256SivAead>::from_slice(&c2s);
let s2c = *aead::Key::<Aes256SivAead>::from_slice(&s2c);

let c2s = Box::new(AesSivCmac512::new(c2s));
let s2c = Box::new(AesSivCmac512::new(s2c));

NtsKeys { c2s, s2c }
}
},
};

return match self.send_response(protocol, algorithm, keys) {
Expand Down Expand Up @@ -1979,7 +2113,10 @@ mod test {
fn server_decoder_finds_algorithm() {
let result = server_roundtrip(&NtsRecord::client_key_exchange_records()).unwrap();

assert_eq!(result.algorithm, AeadAlgorithm::AeadAesSivCmac512);
assert_eq!(
result.key_method.algorithm(),
AeadAlgorithm::AeadAesSivCmac512
);
}

#[test]
Expand All @@ -1993,7 +2130,10 @@ mod test {
);

let result = server_roundtrip(&records).unwrap();
assert_eq!(result.algorithm, AeadAlgorithm::AeadAesSivCmac512);
assert_eq!(
result.key_method.algorithm(),
AeadAlgorithm::AeadAesSivCmac512
);
}

#[test]
Expand All @@ -2016,7 +2156,10 @@ mod test {
);

let result = server_roundtrip(&records).unwrap();
assert_eq!(result.algorithm, AeadAlgorithm::AeadAesSivCmac512);
assert_eq!(
result.key_method.algorithm(),
AeadAlgorithm::AeadAesSivCmac512
);
}

#[test]
Expand All @@ -2025,7 +2168,10 @@ mod test {
records.insert(0, NtsRecord::Warning { warningcode: 42 });

let result = server_roundtrip(&records).unwrap();
assert_eq!(result.algorithm, AeadAlgorithm::AeadAesSivCmac512);
assert_eq!(
result.key_method.algorithm(),
AeadAlgorithm::AeadAesSivCmac512
);
}

#[test]
Expand All @@ -2041,7 +2187,10 @@ mod test {
);

let result = server_roundtrip(&records).unwrap();
assert_eq!(result.algorithm, AeadAlgorithm::AeadAesSivCmac512);
assert_eq!(
result.key_method.algorithm(),
AeadAlgorithm::AeadAesSivCmac512
);
}

#[test]
Expand Down Expand Up @@ -2294,6 +2443,35 @@ mod test {
assert_eq!(result.protocol_version, ProtocolVersion::V5);
}

#[test]
fn test_keyexchange_roundtrip_fixed() {
let (mut client, server) = client_server_pair();

let c2s: Vec<_> = (0..).take(64).collect();
let s2c: Vec<_> = (0..).skip(64).take(64).collect();

let mut buffer = Vec::with_capacity(1024);
for record in NtsRecord::client_key_exchange_records_fixed(c2s.clone(), s2c.clone()) {
record.write(&mut buffer).unwrap();
}
client.tls_connection.writer().write_all(&buffer).unwrap();

let keyset = server.keyset.clone();
let mut result = keyexchange_loop(client, server).unwrap();

assert_eq!(&result.remote, "localhost");
assert_eq!(result.port, 123);

let cookie = result.nts.get_cookie().unwrap();
let cookie = keyset.decode_cookie(&cookie).unwrap();

assert_eq!(cookie.c2s.key_bytes(), c2s);
assert_eq!(cookie.s2c.key_bytes(), s2c);

#[cfg(feature = "ntpv5")]
assert_eq!(result.protocol_version, ProtocolVersion::V5);
}

#[test]
fn test_keyexchange_invalid_input() {
let mut buffer = Vec::with_capacity(1024);
Expand Down

0 comments on commit d248da0

Please sign in to comment.