Skip to content

Commit

Permalink
Pass ConnectionId by value
Browse files Browse the repository at this point in the history
  • Loading branch information
gretchenfrage committed Dec 21, 2024
1 parent 16f83d1 commit 30b5b60
Show file tree
Hide file tree
Showing 9 changed files with 67 additions and 70 deletions.
8 changes: 4 additions & 4 deletions perf/src/noprotection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ impl NoProtectionServerConfig {

// forward all calls to inner except those related to packet encryption/decryption
impl crypto::Session for NoProtectionSession {
fn initial_keys(&self, dst_cid: &ConnectionId, side: Side) -> crypto::Keys {
fn initial_keys(&self, dst_cid: ConnectionId, side: Side) -> crypto::Keys {
self.inner.initial_keys(dst_cid, side)
}

Expand Down Expand Up @@ -115,7 +115,7 @@ impl crypto::Session for NoProtectionSession {
Some(Self::wrap_packet_keys(keys))
}

fn is_valid_retry(&self, orig_dst_cid: &ConnectionId, header: &[u8], payload: &[u8]) -> bool {
fn is_valid_retry(&self, orig_dst_cid: ConnectionId, header: &[u8], payload: &[u8]) -> bool {
self.inner.is_valid_retry(orig_dst_cid, header, payload)
}

Expand Down Expand Up @@ -149,12 +149,12 @@ impl crypto::ServerConfig for NoProtectionServerConfig {
fn initial_keys(
&self,
version: u32,
dst_cid: &ConnectionId,
dst_cid: ConnectionId,
) -> Result<crypto::Keys, crypto::UnsupportedVersion> {
self.inner.initial_keys(version, dst_cid)
}

fn retry_tag(&self, version: u32, orig_dst_cid: &ConnectionId, packet: &[u8]) -> [u8; 16] {
fn retry_tag(&self, version: u32, orig_dst_cid: ConnectionId, packet: &[u8]) -> [u8; 16] {
self.inner.retry_tag(version, orig_dst_cid, packet)
}

Expand Down
6 changes: 3 additions & 3 deletions quinn-proto/src/cid_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ pub trait ConnectionIdGenerator: Send + Sync {
/// Quickly determine whether `cid` could have been generated by this generator
///
/// False positives are permitted, but increase the cost of handling invalid packets.
fn validate(&self, _cid: &ConnectionId) -> Result<(), InvalidCid> {
fn validate(&self, _cid: ConnectionId) -> Result<(), InvalidCid> {
Ok(())
}

Expand Down Expand Up @@ -143,7 +143,7 @@ impl ConnectionIdGenerator for HashedConnectionIdGenerator {
ConnectionId::new(&bytes_arr)
}

fn validate(&self, cid: &ConnectionId) -> Result<(), InvalidCid> {
fn validate(&self, cid: ConnectionId) -> Result<(), InvalidCid> {
let (nonce, signature) = cid.split_at(NONCE_LEN);
let mut hasher = rustc_hash::FxHasher::default();
hasher.write_u64(self.key);
Expand Down Expand Up @@ -175,6 +175,6 @@ mod tests {
fn validate_keyed_cid() {
let mut generator = HashedConnectionIdGenerator::new();
let cid = generator.generate_cid();
generator.validate(&cid).unwrap();
generator.validate(cid).unwrap();
}
}
6 changes: 3 additions & 3 deletions quinn-proto/src/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ impl Connection {
let connection_side = ConnectionSide::from(side_args);
let side = connection_side.side();
let initial_space = PacketSpace {
crypto: Some(crypto.initial_keys(&init_cid, side)),
crypto: Some(crypto.initial_keys(init_cid, side)),
..PacketSpace::new(now)
};
let state = State::Handshake(state::Handshake {
Expand Down Expand Up @@ -2362,7 +2362,7 @@ impl Connection {
if self.total_authed_packets > 1
|| packet.payload.len() <= 16 // token + 16 byte tag
|| !self.crypto.is_valid_retry(
&self.rem_cids.active(),
self.rem_cids.active(),
&packet.header_data,
&packet.payload,
)
Expand Down Expand Up @@ -2391,7 +2391,7 @@ impl Connection {

self.discard_space(now, SpaceId::Initial); // Make sure we clean up after any retransmitted Initials
self.spaces[SpaceId::Initial] = PacketSpace {
crypto: Some(self.crypto.initial_keys(&rem_cid, self.side.side())),
crypto: Some(self.crypto.initial_keys(rem_cid, self.side.side())),
next_packet_number: self.spaces[SpaceId::Initial].next_packet_number,
crypto_offset: client_hello.len() as u64,
..PacketSpace::new(now)
Expand Down
13 changes: 5 additions & 8 deletions quinn-proto/src/crypto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ pub mod rustls;
/// A cryptographic session (commonly TLS)
pub trait Session: Send + Sync + 'static {
/// Create the initial set of keys given the client's initial destination ConnectionId
fn initial_keys(&self, dst_cid: &ConnectionId, side: Side) -> Keys;
fn initial_keys(&self, dst_cid: ConnectionId, side: Side) -> Keys;

/// Get data negotiated during the handshake, if available
///
Expand Down Expand Up @@ -77,7 +77,7 @@ pub trait Session: Send + Sync + 'static {
fn next_1rtt_keys(&mut self) -> Option<KeyPair<Box<dyn PacketKey>>>;

/// Verify the integrity of a retry packet
fn is_valid_retry(&self, orig_dst_cid: &ConnectionId, header: &[u8], payload: &[u8]) -> bool;
fn is_valid_retry(&self, orig_dst_cid: ConnectionId, header: &[u8], payload: &[u8]) -> bool;

/// Fill `output` with `output.len()` bytes of keying material derived
/// from the [Session]'s secrets, using `label` and `context` for domain
Expand Down Expand Up @@ -123,16 +123,13 @@ pub trait ClientConfig: Send + Sync {
/// Server-side configuration for the crypto protocol
pub trait ServerConfig: Send + Sync {
/// Create the initial set of keys given the client's initial destination ConnectionId
fn initial_keys(
&self,
version: u32,
dst_cid: &ConnectionId,
) -> Result<Keys, UnsupportedVersion>;
fn initial_keys(&self, version: u32, dst_cid: ConnectionId)
-> Result<Keys, UnsupportedVersion>;

/// Generate the integrity tag for a retry packet
///
/// Never called if `initial_keys` rejected `version`.
fn retry_tag(&self, version: u32, orig_dst_cid: &ConnectionId, packet: &[u8]) -> [u8; 16];
fn retry_tag(&self, version: u32, orig_dst_cid: ConnectionId, packet: &[u8]) -> [u8; 16];

/// Start a server session with this configuration
///
Expand Down
16 changes: 8 additions & 8 deletions quinn-proto/src/crypto/rustls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ impl TlsSession {
}

impl crypto::Session for TlsSession {
fn initial_keys(&self, dst_cid: &ConnectionId, side: Side) -> Keys {
fn initial_keys(&self, dst_cid: ConnectionId, side: Side) -> Keys {
initial_keys(self.version, dst_cid, side, &self.suite)
}

Expand Down Expand Up @@ -162,7 +162,7 @@ impl crypto::Session for TlsSession {
})
}

fn is_valid_retry(&self, orig_dst_cid: &ConnectionId, header: &[u8], payload: &[u8]) -> bool {
fn is_valid_retry(&self, orig_dst_cid: ConnectionId, header: &[u8], payload: &[u8]) -> bool {
let tag_start = match payload.len().checked_sub(16) {
Some(x) => x,
None => return false,
Expand All @@ -171,7 +171,7 @@ impl crypto::Session for TlsSession {
let mut pseudo_packet =
Vec::with_capacity(header.len() + payload.len() + orig_dst_cid.len() + 1);
pseudo_packet.push(orig_dst_cid.len() as u8);
pseudo_packet.extend_from_slice(orig_dst_cid);
pseudo_packet.extend_from_slice(&orig_dst_cid);
pseudo_packet.extend_from_slice(header);
let tag_start = tag_start + pseudo_packet.len();
pseudo_packet.extend_from_slice(payload);
Expand Down Expand Up @@ -501,13 +501,13 @@ impl crypto::ServerConfig for QuicServerConfig {
fn initial_keys(
&self,
version: u32,
dst_cid: &ConnectionId,
dst_cid: ConnectionId,
) -> Result<Keys, UnsupportedVersion> {
let version = interpret_version(version)?;
Ok(initial_keys(version, dst_cid, Side::Server, &self.initial))
}

fn retry_tag(&self, version: u32, orig_dst_cid: &ConnectionId, packet: &[u8]) -> [u8; 16] {
fn retry_tag(&self, version: u32, orig_dst_cid: ConnectionId, packet: &[u8]) -> [u8; 16] {
// Safe: `start_session()` is never called if `initial_keys()` rejected `version`
let version = interpret_version(version).unwrap();
let (nonce, key) = match version {
Expand All @@ -518,7 +518,7 @@ impl crypto::ServerConfig for QuicServerConfig {

let mut pseudo_packet = Vec::with_capacity(packet.len() + orig_dst_cid.len() + 1);
pseudo_packet.push(orig_dst_cid.len() as u8);
pseudo_packet.extend_from_slice(orig_dst_cid);
pseudo_packet.extend_from_slice(&orig_dst_cid);
pseudo_packet.extend_from_slice(packet);

let nonce = aead::Nonce::assume_unique_for_key(nonce);
Expand Down Expand Up @@ -564,11 +564,11 @@ fn to_vec(params: &TransportParameters) -> Vec<u8> {

pub(crate) fn initial_keys(
version: Version,
dst_cid: &ConnectionId,
dst_cid: ConnectionId,
side: Side,
suite: &Suite,
) -> Keys {
let keys = suite.keys(dst_cid, side.into(), version);
let keys = suite.keys(&dst_cid, side.into(), version);
Keys {
header: KeyPair {
local: Box::new(keys.local.header),
Expand Down
40 changes: 20 additions & 20 deletions quinn-proto/src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ impl Endpoint {
RetireConnectionId(now, seq, allow_more_cids) => {
if let Some(cid) = self.connections[ch].loc_cids.remove(&seq) {
trace!("peer retired CID {}: {}", seq, cid);
self.index.retire(&cid);
self.index.retire(cid);
if allow_more_cids {
return Some(self.send_new_identifiers(now, ch, 1));
}
Expand Down Expand Up @@ -276,7 +276,7 @@ impl Endpoint {
header.version,
addresses,
&crypto,
&header.src_cid,
header.src_cid,
reason,
buf,
)));
Expand Down Expand Up @@ -329,7 +329,7 @@ impl Endpoint {
now: Instant,
inciting_dgram_len: usize,
addresses: FourTuple,
dst_cid: &ConnectionId,
dst_cid: ConnectionId,
buf: &mut Vec<u8>,
) -> Option<Transmit> {
if self
Expand Down Expand Up @@ -451,7 +451,7 @@ impl Endpoint {
ids.push(IssuedCid {
sequence,
id,
reset_token: ResetToken::new(&*self.config.reset_key, &id),
reset_token: ResetToken::new(&*self.config.reset_key, id),
});
}
ConnectionEvent(ConnectionEventInner::NewIdentifiers(ids, now))
Expand Down Expand Up @@ -502,7 +502,7 @@ impl Endpoint {
header.version,
addresses,
&crypto,
&header.src_cid,
header.src_cid,
TransportError::INVALID_TOKEN(""),
buf,
)));
Expand Down Expand Up @@ -577,7 +577,7 @@ impl Endpoint {
version,
incoming.addresses,
&incoming.crypto,
&src_cid,
src_cid,
TransportError::CONNECTION_REFUSED(""),
buf,
)),
Expand Down Expand Up @@ -613,7 +613,7 @@ impl Endpoint {
Some(&server_config),
&mut self.rng,
);
params.stateless_reset_token = Some(ResetToken::new(&*self.config.reset_key, &loc_cid));
params.stateless_reset_token = Some(ResetToken::new(&*self.config.reset_key, loc_cid));
params.original_dst_cid = Some(incoming.token.orig_dst_cid);
params.retry_src_cid = incoming.token.retry_src_cid;
let mut pref_addr_cid = None;
Expand All @@ -626,7 +626,7 @@ impl Endpoint {
address_v4: server_config.preferred_address_v4,
address_v6: server_config.preferred_address_v6,
connection_id: cid,
stateless_reset_token: ResetToken::new(&*self.config.reset_key, &cid),
stateless_reset_token: ResetToken::new(&*self.config.reset_key, cid),
});
}

Expand Down Expand Up @@ -675,7 +675,7 @@ impl Endpoint {
version,
incoming.addresses,
&incoming.crypto,
&src_cid,
src_cid,
e.clone(),
buf,
)),
Expand Down Expand Up @@ -725,7 +725,7 @@ impl Endpoint {
incoming.packet.header.version,
incoming.addresses,
&incoming.crypto,
&incoming.packet.header.src_cid,
incoming.packet.header.src_cid,
TransportError::CONNECTION_REFUSED(""),
buf,
)
Expand Down Expand Up @@ -759,7 +759,7 @@ impl Endpoint {
.encode(
&*server_config.token_key,
&incoming.addresses.remote,
&loc_cid,
loc_cid,
);

let header = Header::Retry {
Expand All @@ -772,7 +772,7 @@ impl Endpoint {
buf.put_slice(&token);
buf.extend_from_slice(&server_config.crypto.retry_tag(
incoming.packet.header.version,
&incoming.packet.header.dst_cid,
incoming.packet.header.dst_cid,
buf,
));
encode.finish(buf, &*incoming.crypto.header.local, None);
Expand Down Expand Up @@ -868,7 +868,7 @@ impl Endpoint {
version: u32,
addresses: FourTuple,
crypto: &Keys,
remote_id: &ConnectionId,
remote_id: ConnectionId,
reason: TransportError,
buf: &mut Vec<u8>,
) -> Transmit {
Expand All @@ -878,7 +878,7 @@ impl Endpoint {
let local_id = self.local_cid_generator.generate_cid();
let number = PacketNumber::U8(0);
let header = Header::Initial(InitialHeader {
dst_cid: *remote_id,
dst_cid: remote_id,
src_cid: local_id,
number,
token: Bytes::new(),
Expand Down Expand Up @@ -1066,8 +1066,8 @@ impl ConnectionIndex {
}

/// Discard a connection ID
fn retire(&mut self, dst_cid: &ConnectionId) {
self.connection_ids.remove(dst_cid);
fn retire(&mut self, dst_cid: ConnectionId) {
self.connection_ids.remove(&dst_cid);
}

/// Remove all references to a connection
Expand All @@ -1089,12 +1089,12 @@ impl ConnectionIndex {
/// Find the existing connection that `datagram` should be routed to, if any
fn get(&self, addresses: &FourTuple, datagram: &PartialDecode) -> Option<RouteDatagramTo> {
if datagram.dst_cid().len() != 0 {
if let Some(&ch) = self.connection_ids.get(datagram.dst_cid()) {
if let Some(&ch) = self.connection_ids.get(&datagram.dst_cid()) {
return Some(RouteDatagramTo::Connection(ch));
}
}
if datagram.is_initial() || datagram.is_0rtt() {
if let Some(&ch) = self.connection_ids_initial.get(datagram.dst_cid()) {
if let Some(&ch) = self.connection_ids_initial.get(&datagram.dst_cid()) {
return Some(ch);
}
}
Expand Down Expand Up @@ -1203,8 +1203,8 @@ impl Incoming {
}

/// The original destination connection ID sent by the client
pub fn orig_dst_cid(&self) -> &ConnectionId {
&self.token.orig_dst_cid
pub fn orig_dst_cid(&self) -> ConnectionId {
self.token.orig_dst_cid
}
}

Expand Down
Loading

0 comments on commit 30b5b60

Please sign in to comment.