Skip to content

Commit

Permalink
proto: Convert TokenPayload into enum
Browse files Browse the repository at this point in the history
As of this commit, it only has a single variant, which is Retry.
However, the next commit will add an additional variant. In addition
to pure refactors, a discriminant byte is used when encoding.
  • Loading branch information
gretchenfrage committed Jan 26, 2025
1 parent 9e1f77c commit 78bfa5b
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 60 deletions.
2 changes: 1 addition & 1 deletion quinn-proto/src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -730,7 +730,7 @@ impl Endpoint {
// retried by the application layer.
let loc_cid = self.local_cid_generator.generate_cid();

let payload = TokenPayload {
let payload = TokenPayload::Retry {
address: incoming.addresses.remote,
orig_dst_cid: incoming.packet.header.dst_cid,
issued: server_config.time_source.now(),
Expand Down
157 changes: 98 additions & 59 deletions quinn-proto/src/token.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,22 +55,27 @@ impl IncomingToken {
return Ok(unvalidated);
};

// Validate token
if retry.payload.address != remote_address {
return Err(InvalidRetryTokenError);
}
if retry.payload.issued + server_config.retry_token_lifetime
< server_config.time_source.now()
{
return Err(InvalidRetryTokenError);
// Validate token, then convert into Self
match retry.payload {
TokenPayload::Retry {
address,
orig_dst_cid,
issued,
} => {
if address != remote_address {
return Err(InvalidRetryTokenError);
}
if issued + server_config.retry_token_lifetime < server_config.time_source.now() {
return Err(InvalidRetryTokenError);
}

Ok(Self {
retry_src_cid: Some(header.dst_cid),
orig_dst_cid,
validated: true,
})
}
}

// Convert token into Self
Ok(Self {
retry_src_cid: Some(header.dst_cid),
orig_dst_cid: retry.payload.orig_dst_cid,
validated: true,
})
}
}

Expand Down Expand Up @@ -101,9 +106,18 @@ impl Token {
let mut buf = Vec::new();

// Encode payload
encode_addr(&mut buf, self.payload.address);
self.payload.orig_dst_cid.encode_long(&mut buf);
encode_unix_secs(&mut buf, self.payload.issued);
match self.payload {
TokenPayload::Retry {
address,
orig_dst_cid,
issued,
} => {
buf.put_u8(TokenType::Retry as u8);
encode_addr(&mut buf, address);
orig_dst_cid.encode_long(&mut buf);
encode_unix_secs(&mut buf, issued);
}
}

// Encrypt
let aead_key = key.aead_from_hkdf(&self.nonce.to_le_bytes());
Expand All @@ -129,34 +143,48 @@ impl Token {

// Decode payload
let mut reader = &data[..];
let address = decode_addr(&mut reader)?;
let orig_dst_cid = ConnectionId::decode_long(&mut reader)?;
let issued = decode_unix_secs(&mut reader)?;
let payload = match TokenType::from_byte((&mut reader).get::<u8>().ok()?)? {
TokenType::Retry => TokenPayload::Retry {
address: decode_addr(&mut reader)?,
orig_dst_cid: ConnectionId::decode_long(&mut reader)?,
issued: decode_unix_secs(&mut reader)?,
},
};

if !reader.is_empty() {
// Consider extra bytes a decoding error (it may be from an incompatible endpoint)
return None;
}

Some(Self {
nonce,
payload: TokenPayload {
address,
orig_dst_cid,
issued,
},
})
Some(Self { nonce, payload })
}
}

/// Content of a [`Token`] that is encrypted from the client
pub(crate) struct TokenPayload {
/// The client's address
pub(crate) address: SocketAddr,
/// The destination connection ID set in the very first packet from the client
pub(crate) orig_dst_cid: ConnectionId,
/// The time at which this token was issued
pub(crate) issued: SystemTime,
pub(crate) enum TokenPayload {
/// Token originating from a Retry packet
Retry {
/// The client's address
address: SocketAddr,
/// The destination connection ID set in the very first packet from the client
orig_dst_cid: ConnectionId,
/// The time at which this token was issued
issued: SystemTime,
},
}

/// Variant tag for a [`TokenPayload`]
#[derive(Copy, Clone)]
#[repr(u8)]
enum TokenType {
Retry = 0,
}

impl TokenType {
fn from_byte(n: u8) -> Option<Self> {
use TokenType::*;
[Retry].into_iter().find(|ty| *ty as u8 == n)
}
}

fn encode_addr(buf: &mut Vec<u8>, address: SocketAddr) {
Expand Down Expand Up @@ -253,43 +281,54 @@ impl fmt::Display for ResetToken {

#[cfg(all(test, any(feature = "aws-lc-rs", feature = "ring")))]
mod test {
use super::*;
#[cfg(all(feature = "aws-lc-rs", not(feature = "ring")))]
use aws_lc_rs::hkdf;
use rand::prelude::*;
#[cfg(feature = "ring")]
use ring::hkdf;

fn token_round_trip(payload: TokenPayload) -> TokenPayload {
let rng = &mut rand::thread_rng();
let token = Token::new(payload, rng);
let mut master_key = [0; 64];
rng.fill_bytes(&mut master_key);
let prk = hkdf::Salt::new(hkdf::HKDF_SHA256, &[]).extract(&master_key);
let encoded = token.encode(&prk);
let decoded = Token::decode(&prk, &encoded).expect("token didn't decrypt / decode");
assert_eq!(token.nonce, decoded.nonce);
decoded.payload
}

#[test]
fn token_sanity() {
use super::*;
fn retry_token_sanity() {
use crate::cid_generator::{ConnectionIdGenerator, RandomConnectionIdGenerator};
use crate::MAX_CID_SIZE;
use crate::{Duration, UNIX_EPOCH};

use rand::RngCore;
use std::net::Ipv6Addr;

let rng = &mut rand::thread_rng();

let mut master_key = [0; 64];
rng.fill_bytes(&mut master_key);

let prk = hkdf::Salt::new(hkdf::HKDF_SHA256, &[]).extract(&master_key);

let address = SocketAddr::new(Ipv6Addr::LOCALHOST.into(), 4433);
let token = Token {
nonce: rng.gen(),
payload: TokenPayload {
address,
orig_dst_cid: RandomConnectionIdGenerator::new(MAX_CID_SIZE).generate_cid(),
issued: UNIX_EPOCH + Duration::from_secs(42), // Fractional seconds would be lost
},
let address_1 = SocketAddr::new(Ipv6Addr::LOCALHOST.into(), 4433);
let orig_dst_cid_1 = RandomConnectionIdGenerator::new(MAX_CID_SIZE).generate_cid();
let issued_1 = UNIX_EPOCH + Duration::from_secs(42); // Fractional seconds would be lost
let payload_1 = TokenPayload::Retry {
address: address_1,
orig_dst_cid: orig_dst_cid_1,
issued: issued_1,
};
#[allow(irrefutable_let_patterns)] // TEMPORARY until next commit
let TokenPayload::Retry {
address: address_2,
orig_dst_cid: orig_dst_cid_2,
issued: issued_2,
} = token_round_trip(payload_1)
else {
panic!("token decoded as wrong variant");
};
let encoded = token.encode(&prk);

let decoded = Token::decode(&prk, &encoded).expect("token didn't validate");
assert_eq!(token.payload.address, decoded.payload.address);
assert_eq!(token.payload.orig_dst_cid, decoded.payload.orig_dst_cid);
assert_eq!(token.payload.issued, decoded.payload.issued);
assert_eq!(address_1, address_2);
assert_eq!(orig_dst_cid_1, orig_dst_cid_2);
assert_eq!(issued_1, issued_2);
}

#[test]
Expand Down

0 comments on commit 78bfa5b

Please sign in to comment.