diff --git a/quinn-proto/src/tests/util.rs b/quinn-proto/src/tests/util.rs index 03522bc97..eb6ad4f34 100644 --- a/quinn-proto/src/tests/util.rs +++ b/quinn-proto/src/tests/util.rs @@ -1,6 +1,6 @@ use std::{ cmp, - collections::{HashMap, VecDeque}, + collections::{HashMap, HashSet, VecDeque}, env, io::{self, Write}, mem, @@ -560,14 +560,24 @@ impl Write for TestWriter { } pub(super) fn server_config() -> ServerConfig { - ServerConfig::with_crypto(Arc::new(server_crypto())) + let mut config = ServerConfig::with_crypto(Arc::new(server_crypto())); + config + .validation_token + .sent(2) + .log(Arc::new(SimpleTokenLog::default())); + config } pub(super) fn server_config_with_cert( cert: CertificateDer<'static>, key: PrivateKeyDer<'static>, ) -> ServerConfig { - ServerConfig::with_crypto(Arc::new(server_crypto_with_cert(cert, key))) + let mut config = ServerConfig::with_crypto(Arc::new(server_crypto_with_cert(cert, key))); + config + .validation_token + .sent(2) + .log(Arc::new(SimpleTokenLog::default())); + config } pub(super) fn server_crypto() -> QuicServerConfig { @@ -605,7 +615,9 @@ fn server_crypto_inner( } pub(super) fn client_config() -> ClientConfig { - ClientConfig::new(Arc::new(client_crypto())) + let mut config = ClientConfig::new(Arc::new(client_crypto())); + config.token_store(Arc::new(SimpleTokenStore::default())); + config } pub(super) fn client_config_with_deterministic_pns() -> ClientConfig { @@ -713,3 +725,43 @@ lazy_static! { pub(crate) static ref CERTIFIED_KEY: rcgen::CertifiedKey = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap(); } + +#[derive(Default)] +struct SimpleTokenLog(Mutex>); + +impl TokenLog for SimpleTokenLog { + fn check_and_insert( + &self, + nonce: u128, + _issued: SystemTime, + _lifetime: Duration, + ) -> Result<(), TokenReuseError> { + if self.0.lock().unwrap().insert(nonce) { + Ok(()) + } else { + Err(TokenReuseError) + } + } +} + +#[derive(Default)] +struct SimpleTokenStore(Mutex>>); + +impl TokenStore for SimpleTokenStore { + fn insert(&self, server_name: &str, token: Bytes) { + self.0 + .lock() + .unwrap() + .entry(server_name.into()) + .or_default() + .push_back(token); + } + + fn take(&self, server_name: &str) -> Option { + self.0 + .lock() + .unwrap() + .get_mut(server_name) + .and_then(|queue| queue.pop_front()) + } +}