diff --git a/bindings/matrix-sdk-crypto-ffi/src/dehydrated_devices.rs b/bindings/matrix-sdk-crypto-ffi/src/dehydrated_devices.rs index c6550c2b2ea..548fb83befb 100644 --- a/bindings/matrix-sdk-crypto-ffi/src/dehydrated_devices.rs +++ b/bindings/matrix-sdk-crypto-ffi/src/dehydrated_devices.rs @@ -54,12 +54,13 @@ impl Drop for DehydratedDevices { #[uniffi::export] impl DehydratedDevices { - pub fn create(&self) -> Arc { - DehydratedDevice { - inner: ManuallyDrop::new(self.inner.create()), + pub fn create(&self) -> Result, DehydrationError> { + let inner = self.runtime.block_on(self.inner.create())?; + + Ok(Arc::new(DehydratedDevice { + inner: ManuallyDrop::new(inner), runtime: self.runtime.to_owned(), - } - .into() + })) } pub fn rehydrate( diff --git a/crates/matrix-sdk-crypto/src/dehydrated_devices.rs b/crates/matrix-sdk-crypto/src/dehydrated_devices.rs index f4ba22fbf09..d7f4f90b165 100644 --- a/crates/matrix-sdk-crypto/src/dehydrated_devices.rs +++ b/crates/matrix-sdk-crypto/src/dehydrated_devices.rs @@ -91,7 +91,7 @@ pub struct DehydratedDevices { impl DehydratedDevices { /// Create a new [`DehydratedDevice`] which can be uploaded to the server. - pub fn create(&self) -> DehydratedDevice { + pub async fn create(&self) -> Result { let user_id = self.inner.user_id(); let user_identity = self.inner.store().private_identity(); @@ -104,9 +104,11 @@ impl DehydratedDevices { store.clone(), ); - let store = Store::new(account, user_identity, store, verification_machine); + let store = + Store::new(account.static_data().clone(), user_identity, store, verification_machine); + store.save_pending_changes(crate::store::PendingChanges { account: Some(account) }).await?; - DehydratedDevice { store } + Ok(DehydratedDevice { store }) } /// Rehydrate the dehydrated device. @@ -289,7 +291,7 @@ impl DehydratedDevice { /// let pickle_key = [0u8; 32]; /// /// // Create the dehydrated device. - /// let device = machine.dehydrated_devices().create(); + /// let device = machine.dehydrated_devices().create().await?; /// /// // Create the request that should upload the device. /// let request = device @@ -316,9 +318,9 @@ impl DehydratedDevice { let mut transaction = self.store.transaction().await?; let account = transaction.account().await?; - account.generate_fallback_key_helper().await; + account.generate_fallback_key_helper(); - let (device_keys, one_time_keys, fallback_keys) = account.keys_for_upload().await; + let (device_keys, one_time_keys, fallback_keys) = account.keys_for_upload(); let mut device_keys = device_keys .expect("We should always try to upload device keys for a dehydrated device."); @@ -329,7 +331,7 @@ impl DehydratedDevice { let pickle_key = expand_pickle_key(pickle_key, &self.store.static_account().device_id); let device_id = self.store.static_account().device_id.clone(); - let device_data = account.dehydrate(&pickle_key).await; + let device_data = account.dehydrate(&pickle_key); let initial_device_display_name = Some(initial_device_display_name); transaction.commit().await?; @@ -456,10 +458,10 @@ mod tests { } #[async_test] - async fn dehydrated_device_creation() { + async fn test_dehydrated_device_creation() { let olm_machine = get_olm_machine().await; - let dehydrated_device = olm_machine.dehydrated_devices().create(); + let dehydrated_device = olm_machine.dehydrated_devices().create().await.unwrap(); let request = dehydrated_device .keys_for_upload("Foo".to_owned(), PICKLE_KEY) @@ -482,7 +484,7 @@ mod tests { let room_id = room_id!("!test:example.org"); let alice = get_olm_machine().await; - let dehydrated_device = alice.dehydrated_devices().create(); + let dehydrated_device = alice.dehydrated_devices().create().await.unwrap(); let mut request = dehydrated_device .keys_for_upload("Foo".to_owned(), PICKLE_KEY) diff --git a/crates/matrix-sdk-crypto/src/gossiping/machine.rs b/crates/matrix-sdk-crypto/src/gossiping/machine.rs index 2b0d2e2b633..cebe9416567 100644 --- a/crates/matrix-sdk-crypto/src/gossiping/machine.rs +++ b/crates/matrix-sdk-crypto/src/gossiping/machine.rs @@ -46,7 +46,7 @@ use crate::{ olm::{InboundGroupSession, Session}, requests::{OutgoingRequest, ToDeviceRequest}, session_manager::GroupSessionCache, - store::{Changes, CryptoStoreError, SecretImportError, Store}, + store::{Changes, CryptoStoreError, SecretImportError, Store, StoreCache}, types::events::{ forwarded_room_key::ForwardedRoomKeyContent, olm_v1::{DecryptedForwardedRoomKeyEvent, DecryptedSecretSendEvent}, @@ -190,7 +190,10 @@ impl GossipMachine { /// Handle all the incoming key requests that are queued up and empty our /// key request queue. - pub async fn collect_incoming_key_requests(&self) -> OlmResult> { + pub async fn collect_incoming_key_requests( + &self, + cache: &StoreCache, + ) -> OlmResult> { let mut changed_sessions = Vec::new(); let incoming_key_requests = @@ -199,8 +202,8 @@ impl GossipMachine { for event in incoming_key_requests.values() { if let Some(s) = match event { #[cfg(feature = "automatic-room-key-forwarding")] - RequestEvent::KeyShare(e) => Box::pin(self.handle_key_request(e)).await?, - RequestEvent::Secret(e) => Box::pin(self.handle_secret_request(e)).await?, + RequestEvent::KeyShare(e) => Box::pin(self.handle_key_request(cache, e)).await?, + RequestEvent::Secret(e) => Box::pin(self.handle_secret_request(cache, e)).await?, #[cfg(not(feature = "automatic-room-key-forwarding"))] _ => None, } { @@ -256,6 +259,7 @@ impl GossipMachine { async fn handle_secret_request( &self, + cache: &StoreCache, event: &SecretRequestEvent, ) -> OlmResult> { let secret_name = match &event.content.action { @@ -331,7 +335,7 @@ impl GossipMachine { ?secret_name, "Received a secret request form an unknown device", ); - self.inner.store.mark_user_as_changed(&event.sender).await?; + cache.mark_user_as_changed(&self.inner.store, &event.sender).await?; None }) @@ -380,6 +384,7 @@ impl GossipMachine { #[cfg(feature = "automatic-room-key-forwarding")] async fn answer_room_key_request( &self, + cache: &StoreCache, event: &RoomKeyRequestEvent, session: &InboundGroupSession, ) -> OlmResult> { @@ -390,7 +395,7 @@ impl GossipMachine { let Some(device) = device else { warn!("Received a key request from an unknown device"); - self.inner.store.mark_user_as_changed(&event.sender).await?; + cache.mark_user_as_changed(&self.inner.store, &event.sender).await?; return Ok(None); }; @@ -429,6 +434,7 @@ impl GossipMachine { )] async fn handle_supported_key_request( &self, + cache: &StoreCache, event: &RoomKeyRequestEvent, room_id: &RoomId, session_id: &str, @@ -436,7 +442,7 @@ impl GossipMachine { let session = self.inner.store.get_inbound_group_session(room_id, session_id).await?; if let Some(s) = session { - self.answer_room_key_request(event, &s).await + self.answer_room_key_request(cache, event, &s).await } else { debug!("Received a room key request for an unknown inbound group session",); @@ -446,14 +452,19 @@ impl GossipMachine { /// Handle a single incoming key request. #[cfg(feature = "automatic-room-key-forwarding")] - async fn handle_key_request(&self, event: &RoomKeyRequestEvent) -> OlmResult> { + async fn handle_key_request( + &self, + cache: &StoreCache, + event: &RoomKeyRequestEvent, + ) -> OlmResult> { use crate::types::events::room_key_request::{Action, RequestedKeyInfo}; if self.inner.room_key_forwarding_enabled.load(Ordering::SeqCst) { match &event.content.action { Action::Request(info) => match info { RequestedKeyInfo::MegolmV1AesSha2(i) => { - self.handle_supported_key_request(event, &i.room_id, &i.session_id).await + self.handle_supported_key_request(cache, event, &i.room_id, &i.session_id) + .await } #[cfg(feature = "experimental-algorithms")] RequestedKeyInfo::MegolmV2AesSha2(i) => { @@ -832,6 +843,7 @@ impl GossipMachine { async fn receive_secret( &self, + cache: &StoreCache, sender_key: Curve25519PublicKey, secret: GossippedSecret, changes: &mut Changes, @@ -850,7 +862,7 @@ impl GossipMachine { } else { warn!("Received a m.secret.send event from an unknown device"); - self.inner.store.mark_user_as_changed(&secret.event.sender).await?; + cache.mark_user_as_changed(&self.inner.store, &secret.event.sender).await?; } Ok(()) @@ -859,6 +871,7 @@ impl GossipMachine { #[instrument(skip_all, fields(sender_key, sender = ?event.sender, request_id = ?event.content.request_id, secret_name))] pub async fn receive_secret_event( &self, + cache: &StoreCache, sender_key: Curve25519PublicKey, event: &DecryptedSecretSendEvent, changes: &mut Changes, @@ -887,7 +900,7 @@ impl GossipMachine { gossip_request: request, }; - self.receive_secret(sender_key, secret, changes).await?; + self.receive_secret(cache, sender_key, secret, changes).await?; Some(secret_name) } @@ -1069,7 +1082,7 @@ mod tests { identities::{LocalTrust, ReadOnlyDevice}, olm::{Account, PrivateCrossSigningIdentity}, session_manager::GroupSessionCache, - store::{CryptoStoreWrapper, MemoryStore, Store}, + store::{CryptoStoreWrapper, MemoryStore, PendingChanges, Store}, types::events::room::encrypted::{EncryptedEvent, RoomEncryptedEventContent}, verification::VerificationMachine, }; @@ -1113,7 +1126,7 @@ mod tests { } #[cfg(feature = "automatic-room-key-forwarding")] - fn test_gossip_machine(user_id: &UserId) -> GossipMachine { + async fn gossip_machine_test_helper(user_id: &UserId) -> GossipMachine { let user_id = user_id.to_owned(); let device_id = DeviceId::new(); @@ -1122,27 +1135,29 @@ mod tests { let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(alice_id()))); let verification = VerificationMachine::new(account.static_data.clone(), identity.clone(), store.clone()); - let store = Store::new(account, identity, store, verification); + let store = Store::new(account.static_data().clone(), identity, store, verification); + store.save_pending_changes(PendingChanges { account: Some(account) }).await.unwrap(); + let session_cache = GroupSessionCache::new(store.clone()); GossipMachine::new(store, session_cache, Default::default()) } - async fn get_machine() -> GossipMachine { + async fn get_machine_test_helper() -> GossipMachine { let user_id = alice_id().to_owned(); let account = Account::with_device_id(&user_id, alice_device_id()); - let device = ReadOnlyDevice::from_account(&account).await; + let device = ReadOnlyDevice::from_account(&account); let another_device = - ReadOnlyDevice::from_account(&Account::with_device_id(&user_id, alice2_device_id())) - .await; + ReadOnlyDevice::from_account(&Account::with_device_id(&user_id, alice2_device_id())); let store = Arc::new(CryptoStoreWrapper::new(&user_id, MemoryStore::new())); let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(alice_id()))); let verification = VerificationMachine::new(account.static_data.clone(), identity.clone(), store.clone()); - let store = Store::new(account, identity, store, verification); + let store = Store::new(account.static_data().clone(), identity, store, verification); store.save_devices(&[device, another_device]).await.unwrap(); + store.save_pending_changes(PendingChanges { account: Some(account) }).await.unwrap(); let session_cache = GroupSessionCache::new(store.clone()); GossipMachine::new(store, session_cache, Default::default()) @@ -1154,19 +1169,20 @@ mod tests { create_sessions: bool, algorithm: EventEncryptionAlgorithm, ) -> (GossipMachine, OutboundGroupSession, GossipMachine) { - let alice_machine = get_machine().await; - let alice_device = - ReadOnlyDevice::from_account(&alice_machine.inner.store.cache().await.unwrap().account) - .await; + let alice_machine = get_machine_test_helper().await; + let alice_device = ReadOnlyDevice::from_account( + &alice_machine.inner.store.cache().await.unwrap().account().await.unwrap(), + ); - let bob_machine = test_gossip_machine(other_machine_owner); + let bob_machine = gossip_machine_test_helper(other_machine_owner).await; - let bob_device = - ReadOnlyDevice::from_account(&bob_machine.inner.store.cache().await.unwrap().account) - .await; + let bob_device = ReadOnlyDevice::from_account( + #[allow(clippy::explicit_auto_deref)] // clippy's wrong + &*bob_machine.inner.store.cache().await.unwrap().account().await.unwrap(), + ); // We need a trusted device, otherwise we won't request keys - let second_device = ReadOnlyDevice::from_account(&alice_2_account()).await; + let second_device = ReadOnlyDevice::from_account(&alice_2_account()); second_device.set_trust_state(LocalTrust::Verified); bob_device.set_trust_state(LocalTrust::Verified); alice_machine.inner.store.save_devices(&[bob_device, second_device]).await.unwrap(); @@ -1293,14 +1309,14 @@ mod tests { #[async_test] async fn create_machine() { - let machine = get_machine().await; + let machine = get_machine_test_helper().await; assert!(machine.outgoing_to_device_requests().await.unwrap().is_empty()); } #[async_test] async fn re_request_keys() { - let machine = get_machine().await; + let machine = get_machine_test_helper().await; let account = account(); let (outbound, session) = account.create_group_session_pair_with_defaults(room_id()).await; @@ -1323,10 +1339,10 @@ mod tests { #[async_test] #[cfg(feature = "automatic-room-key-forwarding")] async fn create_key_request() { - let machine = get_machine().await; + let machine = get_machine_test_helper().await; let account = account(); let second_account = alice_2_account(); - let alice_device = ReadOnlyDevice::from_account(&second_account).await; + let alice_device = ReadOnlyDevice::from_account(&second_account); // We need a trusted device, otherwise we won't request keys alice_device.set_trust_state(LocalTrust::Verified); @@ -1355,11 +1371,11 @@ mod tests { #[async_test] #[cfg(feature = "automatic-room-key-forwarding")] async fn receive_forwarded_key() { - let machine = get_machine().await; + let machine = get_machine_test_helper().await; let account = account(); let second_account = alice_2_account(); - let alice_device = ReadOnlyDevice::from_account(&second_account).await; + let alice_device = ReadOnlyDevice::from_account(&second_account); // We need a trusted device, otherwise we won't request keys alice_device.set_trust_state(LocalTrust::Verified); @@ -1464,7 +1480,7 @@ mod tests { #[async_test] #[cfg(feature = "automatic-room-key-forwarding")] async fn should_share_key_test() { - let machine = get_machine().await; + let machine = get_machine_test_helper().await; let account = account(); let own_device = @@ -1481,7 +1497,7 @@ mod tests { // Now we do want to share the keys. machine.should_share_key(&own_device, &inbound).await.unwrap(); - let bob_device = ReadOnlyDevice::from_account(&bob_account()).await; + let bob_device = ReadOnlyDevice::from_account(&bob_account()); machine.inner.store.save_devices(&[bob_device]).await.unwrap(); let bob_device = @@ -1539,7 +1555,7 @@ mod tests { // Finally, let's ensure we don't share the session with a device that rotated // its curve25519 key. - let bob_device = ReadOnlyDevice::from_account(&bob_account()).await; + let bob_device = ReadOnlyDevice::from_account(&bob_account()); machine.inner.store.save_devices(&[bob_device]).await.unwrap(); let bob_device = @@ -1594,7 +1610,11 @@ mod tests { // Receive the room key request from alice. bob_machine.receive_incoming_key_request(&event); - bob_machine.collect_incoming_key_requests().await.unwrap(); + + { + let bob_cache = bob_machine.inner.store.cache().await.unwrap(); + bob_machine.collect_incoming_key_requests(&bob_cache).await.unwrap(); + } // Now bob does have an outgoing request. assert!(!bob_machine.inner.outgoing_requests.read().unwrap().is_empty()); @@ -1672,7 +1692,10 @@ mod tests { // Receive the room key request from alice. bob_machine.receive_incoming_key_request(&event); - bob_machine.collect_incoming_key_requests().await.unwrap(); + { + let bob_cache = bob_machine.inner.store.cache().await.unwrap(); + bob_machine.collect_incoming_key_requests(&bob_cache).await.unwrap(); + } // Now bob does have an outgoing request. assert!(!bob_machine.inner.outgoing_requests.read().unwrap().is_empty()); @@ -1731,13 +1754,13 @@ mod tests { #[async_test] async fn test_secret_share_cycle() { - let alice_machine = get_machine().await; + let alice_machine = get_machine_test_helper().await; - let second_account = alice_2_account(); - let alice_device = ReadOnlyDevice::from_account(&second_account).await; + let mut second_account = alice_2_account(); + let alice_device = ReadOnlyDevice::from_account(&second_account); let bob_account = bob_account(); - let bob_device = ReadOnlyDevice::from_account(&bob_account).await; + let bob_device = ReadOnlyDevice::from_account(&bob_account); alice_machine.inner.store.save_devices(&[alice_device.clone()]).await.unwrap(); @@ -1747,7 +1770,8 @@ mod tests { .store .with_transaction(|mut tr| async { let alice_account = tr.account().await?; - let (alice_session, _) = alice_account.create_session_for(&second_account).await; + let (alice_session, _) = + alice_account.create_session_for(&mut second_account).await; Ok((tr, alice_session)) }) .await @@ -1767,13 +1791,19 @@ mod tests { // No secret found assert!(alice_machine.inner.outgoing_requests.read().unwrap().is_empty()); alice_machine.receive_incoming_secret_request(&event); - alice_machine.collect_incoming_key_requests().await.unwrap(); + { + let alice_cache = alice_machine.inner.store.cache().await.unwrap(); + alice_machine.collect_incoming_key_requests(&alice_cache).await.unwrap(); + } assert!(alice_machine.inner.outgoing_requests.read().unwrap().is_empty()); // No device found alice_machine.inner.store.reset_cross_signing_identity().await; alice_machine.receive_incoming_secret_request(&event); - alice_machine.collect_incoming_key_requests().await.unwrap(); + { + let alice_cache = alice_machine.inner.store.cache().await.unwrap(); + alice_machine.collect_incoming_key_requests(&alice_cache).await.unwrap(); + } assert!(alice_machine.inner.outgoing_requests.read().unwrap().is_empty()); alice_machine.inner.store.save_devices(&[bob_device]).await.unwrap(); @@ -1781,7 +1811,10 @@ mod tests { // The device doesn't belong to us alice_machine.inner.store.reset_cross_signing_identity().await; alice_machine.receive_incoming_secret_request(&event); - alice_machine.collect_incoming_key_requests().await.unwrap(); + { + let alice_cache = alice_machine.inner.store.cache().await.unwrap(); + alice_machine.collect_incoming_key_requests(&alice_cache).await.unwrap(); + } assert!(alice_machine.inner.outgoing_requests.read().unwrap().is_empty()); let event = RumaToDeviceEvent { @@ -1795,7 +1828,10 @@ mod tests { // The device isn't trusted alice_machine.receive_incoming_secret_request(&event); - alice_machine.collect_incoming_key_requests().await.unwrap(); + { + let alice_cache = alice_machine.inner.store.cache().await.unwrap(); + alice_machine.collect_incoming_key_requests(&alice_cache).await.unwrap(); + } assert!(alice_machine.inner.outgoing_requests.read().unwrap().is_empty()); // We need a trusted device, otherwise we won't serve secrets @@ -1803,13 +1839,16 @@ mod tests { alice_machine.inner.store.save_devices(&[alice_device.clone()]).await.unwrap(); alice_machine.receive_incoming_secret_request(&event); - alice_machine.collect_incoming_key_requests().await.unwrap(); + { + let alice_cache = alice_machine.inner.store.cache().await.unwrap(); + alice_machine.collect_incoming_key_requests(&alice_cache).await.unwrap(); + } assert!(!alice_machine.inner.outgoing_requests.read().unwrap().is_empty()); } #[async_test] #[cfg(feature = "backups_v1")] - async fn secret_broadcasting() { + async fn test_secret_broadcasting() { use futures_util::{pin_mut, FutureExt}; use ruma::api::client::to_device::send_event_to_device::v3::Response as ToDeviceResponse; use serde_json::value::to_raw_value; @@ -1872,7 +1911,15 @@ mod tests { .await .unwrap(); alice_machine.inner.key_request_machine.receive_incoming_secret_request(&event); - alice_machine.inner.key_request_machine.collect_incoming_key_requests().await.unwrap(); + { + let alice_cache = alice_machine.store().cache().await.unwrap(); + alice_machine + .inner + .key_request_machine + .collect_incoming_key_requests(&alice_cache) + .await + .unwrap(); + } let requests = alice_machine.inner.key_request_machine.outgoing_to_device_requests().await.unwrap(); @@ -1925,7 +1972,10 @@ mod tests { // Receive the room key request from alice. bob_machine.receive_incoming_key_request(&event); - bob_machine.collect_incoming_key_requests().await.unwrap(); + { + let bob_cache = bob_machine.inner.store.cache().await.unwrap(); + bob_machine.collect_incoming_key_requests(&bob_cache).await.unwrap(); + } // Bob only has a keys claim request, since we're lacking a session assert_eq!(bob_machine.outgoing_to_device_requests().await.unwrap().len(), 1); assert_matches!( @@ -1960,7 +2010,10 @@ mod tests { bob_machine.retry_keyshare(alice_id(), alice_device_id()); assert!(bob_machine.inner.users_for_key_claim.read().unwrap().is_empty()); - bob_machine.collect_incoming_key_requests().await.unwrap(); + { + let bob_cache = bob_machine.inner.store.cache().await.unwrap(); + bob_machine.collect_incoming_key_requests(&bob_cache).await.unwrap(); + } // Bob now has an outgoing requests. assert!(!bob_machine.outgoing_to_device_requests().await.unwrap().is_empty()); assert!(bob_machine.inner.wait_queue.is_empty()); diff --git a/crates/matrix-sdk-crypto/src/identities/device.rs b/crates/matrix-sdk-crypto/src/identities/device.rs index 09631ed1347..7de6622c18f 100644 --- a/crates/matrix-sdk-crypto/src/identities/device.rs +++ b/crates/matrix-sdk-crypto/src/identities/device.rs @@ -869,7 +869,7 @@ impl ReadOnlyDevice { pub async fn from_machine_test_helper( machine: &OlmMachine, ) -> Result { - Ok(ReadOnlyDevice::from_account(&machine.store().cache().await?.account).await) + Ok(ReadOnlyDevice::from_account(&*machine.store().cache().await?.account().await?)) } /// Create a `ReadOnlyDevice` from an `Account` @@ -884,8 +884,8 @@ impl ReadOnlyDevice { /// *Don't* use this after we received a keys/query response, other /// users/devices might add signatures to our own device, which can't be /// replicated locally. - pub async fn from_account(account: &Account) -> ReadOnlyDevice { - let device_keys = account.device_keys().await; + pub fn from_account(account: &Account) -> ReadOnlyDevice { + let device_keys = account.device_keys(); let mut device = ReadOnlyDevice::try_from(&device_keys) .expect("Creating a device from our own account should always succeed"); device.first_time_seen_ts = account.creation_local_time(); diff --git a/crates/matrix-sdk-crypto/src/identities/manager.rs b/crates/matrix-sdk-crypto/src/identities/manager.rs index 44fa4459aa9..56c7ea1b279 100644 --- a/crates/matrix-sdk-crypto/src/identities/manager.rs +++ b/crates/matrix-sdk-crypto/src/identities/manager.rs @@ -37,7 +37,7 @@ use crate::{ requests::KeysQueryRequest, store::{ caches::SequenceNumber, Changes, DeviceChanges, IdentityChanges, Result as StoreResult, - Store, + Store, StoreCache, }, types::{CrossSigningKey, DeviceKeys, MasterPubkey, SelfSigningPubkey, UserSigningPubkey}, utilities::FailuresCache, @@ -167,7 +167,10 @@ impl IdentityManager { if let Some(sequence_number) = sequence_number { self.store + .cache() + .await? .mark_tracked_users_as_up_to_date( + &self.store, response.device_keys.keys().map(Deref::deref), sequence_number, ) @@ -697,13 +700,15 @@ impl IdentityManager { // tracking ourselves. // // The check for emptiness is done first for performance. - let (users, sequence_number) = - if users.is_empty() && !self.store.tracked_users().await?.contains(self.user_id()) { - self.store.mark_user_as_changed(self.user_id()).await?; + let (users, sequence_number) = { + let cache = self.store.cache().await?; + if users.is_empty() && !cache.tracked_users().contains(self.user_id()) { + cache.mark_user_as_changed(&self.store, self.user_id()).await?; self.store.users_for_key_query().await? } else { (users, sequence_number) - }; + } + }; if users.is_empty() { Ok(BTreeMap::new()) @@ -755,9 +760,10 @@ impl IdentityManager { /// key query. pub async fn receive_device_changes( &self, + cache: &StoreCache, users: impl Iterator, ) -> StoreResult<()> { - self.store.mark_tracked_users_as_changed(users).await + cache.mark_tracked_users_as_changed(&self.store, users).await } /// See the docs for [`OlmMachine::update_tracked_users()`]. @@ -765,7 +771,7 @@ impl IdentityManager { &self, users: impl IntoIterator, ) -> StoreResult<()> { - self.store.update_tracked_users(users.into_iter()).await + self.store.cache().await?.update_tracked_users(&self.store, users.into_iter()).await } } @@ -785,7 +791,7 @@ pub(crate) mod testing { identities::IdentityManager, machine::testing::response_from_file, olm::{Account, PrivateCrossSigningIdentity}, - store::{CryptoStoreWrapper, MemoryStore, Store}, + store::{CryptoStoreWrapper, MemoryStore, PendingChanges, Store}, types::DeviceKeys, verification::VerificationMachine, UploadSigningKeysRequest, @@ -803,15 +809,20 @@ pub(crate) mod testing { device_id!("WSKKLTJZCL") } - pub(crate) async fn manager(user_id: &UserId, device_id: &DeviceId) -> IdentityManager { + pub(crate) async fn manager_test_helper( + user_id: &UserId, + device_id: &DeviceId, + ) -> IdentityManager { let identity = PrivateCrossSigningIdentity::new(user_id.into()).await; let identity = Arc::new(Mutex::new(identity)); let user_id = user_id.to_owned(); let account = Account::with_device_id(&user_id, device_id); + let static_account = account.static_data().clone(); let store = Arc::new(CryptoStoreWrapper::new(&user_id, MemoryStore::new())); let verification = - VerificationMachine::new(account.static_data.clone(), identity.clone(), store.clone()); - let store = Store::new(account, identity, store, verification); + VerificationMachine::new(static_account.clone(), identity.clone(), store.clone()); + let store = Store::new(static_account, identity, store, verification); + store.save_pending_changes(PendingChanges { account: Some(account) }).await.unwrap(); IdentityManager::new(store) } @@ -1086,7 +1097,9 @@ pub(crate) mod tests { use serde_json::json; use stream_assert::{assert_closed, assert_pending, assert_ready}; - use super::testing::{device_id, key_query, manager, other_key_query, other_user_id, user_id}; + use super::testing::{ + device_id, key_query, manager_test_helper, other_key_query, other_user_id, user_id, + }; use crate::{ identities::manager::testing::{other_key_query_cross_signed, own_key_query}, olm::PrivateCrossSigningIdentity, @@ -1111,18 +1124,22 @@ pub(crate) mod tests { #[async_test] async fn test_tracked_users() { - let manager = manager(user_id(), device_id()).await; + let manager = manager_test_helper(user_id(), device_id()).await; let alice = user_id!("@alice:example.org"); - assert!( - manager.store.tracked_users().await.unwrap().is_empty(), - "No users are initially tracked" - ); - manager.receive_device_changes([alice].iter().map(Deref::deref)).await.unwrap(); - assert!( - !manager.store.tracked_users().await.unwrap().contains(alice), - "Receiving a device changes update for a user we don't track does nothing" - ); + { + let cache = manager.store.cache().await.unwrap(); + + assert!(cache.tracked_users().is_empty(), "No users are initially tracked"); + + manager.receive_device_changes(&cache, [alice].iter().map(Deref::deref)).await.unwrap(); + + assert!( + !cache.tracked_users().contains(alice), + "Receiving a device changes update for a user we don't track does nothing" + ); + } + assert!( !manager.store.users_for_key_query().await.unwrap().0.contains(alice), "The user we don't track doesn't end up in the `/keys/query` request" @@ -1131,13 +1148,13 @@ pub(crate) mod tests { #[async_test] async fn test_manager_creation() { - let manager = manager(user_id(), device_id()).await; - assert!(manager.store.tracked_users().await.unwrap().is_empty()) + let manager = manager_test_helper(user_id(), device_id()).await; + assert!(manager.store.cache().await.unwrap().tracked_users().is_empty()) } #[async_test] async fn test_manager_key_query_response() { - let manager = manager(user_id(), device_id()).await; + let manager = manager_test_helper(user_id(), device_id()).await; let other_user = other_user_id(); let devices = manager.store.get_user_devices(other_user).await.unwrap(); assert_eq!(devices.devices().count(), 0); @@ -1164,7 +1181,7 @@ pub(crate) mod tests { #[async_test] async fn test_manager_own_key_query_response() { - let manager = manager(user_id(), device_id()).await; + let manager = manager_test_helper(user_id(), device_id()).await; let our_user = user_id(); let devices = manager.store.get_user_devices(our_user).await.unwrap(); assert_eq!(devices.devices().count(), 0); @@ -1174,7 +1191,8 @@ pub(crate) mod tests { let identity_request = private_identity.as_upload_request().await; drop(private_identity); - let device_keys = manager.store.cache().await.unwrap().account.device_keys().await; + let device_keys = + manager.store.cache().await.unwrap().account().await.unwrap().device_keys(); manager .receive_keys_query_response( &TransactionId::new(), @@ -1183,8 +1201,13 @@ pub(crate) mod tests { .await .unwrap(); - let identity = manager.store.get_user_identity(our_user).await.unwrap().unwrap(); - let identity = identity.own().unwrap(); + let identity = manager + .store + .get_user_identity(our_user) + .await + .unwrap() + .expect("missing user identity"); + let identity = identity.own().expect("missing own identity"); assert!(identity.is_verified()); let devices = manager.store.get_user_devices(our_user).await.unwrap(); @@ -1197,9 +1220,9 @@ pub(crate) mod tests { } #[async_test] - async fn private_identity_invalidation_after_public_keys_change() { + async fn test_private_identity_invalidation_after_public_keys_change() { let user_id = user_id!("@example1:localhost"); - let manager = manager(user_id, "DEVICEID".into()).await; + let manager = manager_test_helper(user_id, "DEVICEID".into()).await; let identity_request = { let private_identity = manager.store.private_identity(); @@ -1286,11 +1309,11 @@ pub(crate) mod tests { } #[async_test] - async fn no_tracked_users_key_query_request() { - let manager = manager(user_id(), device_id()).await; + async fn test_no_tracked_users_key_query_request() { + let manager = manager_test_helper(user_id(), device_id()).await; assert!( - manager.store.tracked_users().await.unwrap().is_empty(), + manager.store.cache().await.unwrap().tracked_users().is_empty(), "No users are initially tracked" ); @@ -1298,7 +1321,7 @@ pub(crate) mod tests { assert!(!requests.is_empty(), "We query the keys for our own user"); assert!( - manager.store.tracked_users().await.unwrap().contains(manager.user_id()), + manager.store.cache().await.unwrap().tracked_users().contains(manager.user_id()), "Our own user is now tracked" ); } @@ -1307,8 +1330,8 @@ pub(crate) mod tests { /// user is not removed from the list of outdated users when the /// response is received #[async_test] - async fn invalidation_race_handling() { - let manager = manager(user_id(), device_id()).await; + async fn test_invalidation_race_handling() { + let manager = manager_test_helper(user_id(), device_id()).await; let alice = other_user_id(); manager.update_tracked_users([alice]).await.unwrap(); @@ -1317,7 +1340,10 @@ pub(crate) mod tests { assert!(req.device_keys.contains_key(alice)); // another invalidation turns up - manager.receive_device_changes([alice].into_iter()).await.unwrap(); + { + let cache = manager.store.cache().await.unwrap(); + manager.receive_device_changes(&cache, [alice].into_iter()).await.unwrap(); + } // the response from the query arrives manager.receive_keys_query_response(&reqid, &other_key_query()).await.unwrap(); @@ -1335,20 +1361,22 @@ pub(crate) mod tests { } #[async_test] - async fn failure_handling() { - let manager = manager(user_id(), device_id()).await; + async fn test_failure_handling() { + let manager = manager_test_helper(user_id(), device_id()).await; let alice = user_id!("@alice:example.org"); - assert!( - manager.store.tracked_users().await.unwrap().is_empty(), - "No users are initially tracked" - ); - manager.store.mark_user_as_changed(alice).await.unwrap(); + { + let cache = manager.store.cache().await.unwrap(); + assert!(cache.tracked_users().is_empty(), "No users are initially tracked"); + + cache.mark_user_as_changed(&manager.store, alice).await.unwrap(); + + assert!( + cache.tracked_users().contains(alice), + "Alice is tracked after being marked as tracked" + ); + } - assert!( - manager.store.tracked_users().await.unwrap().contains(alice), - "Alice is tracked after being marked as tracked" - ); let (reqid, req) = manager.users_for_key_query().await.unwrap().pop_first().unwrap(); assert!(req.device_keys.contains_key(alice)); @@ -1376,7 +1404,7 @@ pub(crate) mod tests { #[async_test] async fn test_out_of_band_key_query() { // build the request - let manager = manager(user_id(), device_id()).await; + let manager = manager_test_helper(user_id(), device_id()).await; let (reqid, req) = manager.build_key_query_for_users(vec![user_id()]); assert!(req.device_keys.contains_key(user_id())); @@ -1394,8 +1422,8 @@ pub(crate) mod tests { } #[async_test] - async fn devices_stream() { - let manager = manager(user_id(), device_id()).await; + async fn test_devices_stream() { + let manager = manager_test_helper(user_id(), device_id()).await; let (request_id, _) = manager.build_key_query_for_users(vec![user_id()]); let stream = manager.store.devices_stream(); @@ -1408,8 +1436,8 @@ pub(crate) mod tests { } #[async_test] - async fn identities_stream() { - let manager = manager(user_id(), device_id()).await; + async fn test_identities_stream() { + let manager = manager_test_helper(user_id(), device_id()).await; let (request_id, _) = manager.build_key_query_for_users(vec![user_id()]); let stream = manager.store.user_identities_stream(); @@ -1422,8 +1450,8 @@ pub(crate) mod tests { } #[async_test] - async fn identities_stream_raw() { - let mut manager = Some(manager(user_id(), device_id()).await); + async fn test_identities_stream_raw() { + let mut manager = Some(manager_test_helper(user_id(), device_id()).await); let (request_id, _) = manager.as_ref().unwrap().build_key_query_for_users(vec![user_id()]); let stream = manager.as_ref().unwrap().store.identities_stream_raw(); @@ -1467,8 +1495,8 @@ pub(crate) mod tests { } #[async_test] - async fn identities_stream_raw_signature_update() { - let mut manager = Some(manager(user_id(), device_id()).await); + async fn test_identities_stream_raw_signature_update() { + let mut manager = Some(manager_test_helper(user_id(), device_id()).await); let (request_id, _) = manager.as_ref().unwrap().build_key_query_for_users(vec![other_user_id()]); diff --git a/crates/matrix-sdk-crypto/src/identities/user.rs b/crates/matrix-sdk-crypto/src/identities/user.rs index 4d7a2873f60..707cf0de505 100644 --- a/crates/matrix-sdk-crypto/src/identities/user.rs +++ b/crates/matrix-sdk-crypto/src/identities/user.rs @@ -146,8 +146,9 @@ impl OwnUserIdentity { error!(error = ?e, "Couldn't store our own user identity after marking it as verified"); } - let account = &self.store.cache().await?.account; - account.sign_master_key(self.master_key.clone()).await + let cache = self.store.cache().await?; + let account = cache.account().await?; + account.sign_master_key(self.master_key.clone()) } /// Send a verification request to our other devices. @@ -174,7 +175,7 @@ impl OwnUserIdentity { /// own device keys with our self-signing key. pub async fn trusts_our_own_device(&self) -> Result { Ok(if let Some(signatures) = self.verification_machine.store.device_signatures().await? { - let mut device_keys = self.store.cache().await?.account.device_keys().await; + let mut device_keys = self.store.cache().await?.account().await?.device_keys(); device_keys.signatures = signatures; self.inner.self_signing_key().verify_device_keys(&device_keys).is_ok() diff --git a/crates/matrix-sdk-crypto/src/machine.rs b/crates/matrix-sdk-crypto/src/machine.rs index e8e771d2ae0..62982ce32e6 100644 --- a/crates/matrix-sdk-crypto/src/machine.rs +++ b/crates/matrix-sdk-crypto/src/machine.rs @@ -64,12 +64,13 @@ use crate::{ olm::{ Account, CrossSigningStatus, EncryptionSettings, ExportedRoomKey, IdentityKeys, InboundGroupSession, OlmDecryptionInfo, PrivateCrossSigningIdentity, SessionType, + StaticAccountData, }, requests::{IncomingResponse, OutgoingRequest, UploadSigningKeysRequest}, session_manager::{GroupSessionManager, SessionManager}, store::{ Changes, CryptoStoreWrapper, DeviceChanges, IdentityChanges, IntoCryptoStore, MemoryStore, - PendingChanges, Result as StoreResult, RoomKeyInfo, SecretImportError, Store, + PendingChanges, Result as StoreResult, RoomKeyInfo, SecretImportError, Store, StoreCache, StoreTransaction, }, types::{ @@ -168,25 +169,23 @@ impl OlmMachine { ) -> Result { let account = Account::rehydrate(pickle_key, self.user_id(), device_id, device_data).await?; + let static_account = account.static_data().clone(); let store = Arc::new(CryptoStoreWrapper::new(self.user_id(), MemoryStore::new())); + store.save_pending_changes(PendingChanges { account: Some(account) }).await?; - Ok(Self::new_helper(device_id, store, account, self.store().private_identity())) + Ok(Self::new_helper(device_id, store, static_account, self.store().private_identity())) } fn new_helper( device_id: &DeviceId, store: Arc, - account: Account, + account: StaticAccountData, user_identity: Arc>, ) -> Self { - let verification_machine = VerificationMachine::new( - account.static_data().clone(), - user_identity.clone(), - store.clone(), - ); - let store = - Store::new(account.clone(), user_identity.clone(), store, verification_machine.clone()); + let verification_machine = + VerificationMachine::new(account.clone(), user_identity.clone(), store.clone()); + let store = Store::new(account, user_identity.clone(), store, verification_machine.clone()); let group_session_manager = GroupSessionManager::new(store.clone()); @@ -248,7 +247,8 @@ impl OlmMachine { store: impl IntoCryptoStore, ) -> StoreResult { let store = store.into_crypto_store(); - let account = match store.load_account().await? { + + let static_account = match store.load_account().await? { Some(account) => { if user_id != account.user_id() || device_id != account.device_id() { return Err(CryptoStoreError::MismatchedAccount { @@ -262,16 +262,18 @@ impl OlmMachine { .record("curve25519_key", display(account.identity_keys().curve25519)); debug!("Restored an Olm account"); - account + account.static_data().clone() } + None => { let account = Account::with_device_id(user_id, device_id); + let static_account = account.static_data().clone(); Span::current() .record("ed25519_key", display(account.identity_keys().ed25519)) .record("curve25519_key", display(account.identity_keys().curve25519)); - let device = ReadOnlyDevice::from_account(&account).await; + let device = ReadOnlyDevice::from_account(&account); // We just created this device from our own Olm `Account`. Since we are the // owners of the private keys of this device we can safely mark @@ -283,13 +285,11 @@ impl OlmMachine { ..Default::default() }; store.save_changes(changes).await?; - store - .save_pending_changes(PendingChanges { account: Some(account.clone()) }) - .await?; + store.save_pending_changes(PendingChanges { account: Some(account) }).await?; debug!("Created a new Olm account"); - account + static_account } }; @@ -310,7 +310,7 @@ impl OlmMachine { let identity = Arc::new(Mutex::new(identity)); let store = Arc::new(CryptoStoreWrapper::new(user_id, store)); - Ok(OlmMachine::new_helper(device_id, store, account, identity)) + Ok(OlmMachine::new_helper(device_id, store, static_account, identity)) } /// Get the crypto store associated with this `OlmMachine` instance. @@ -344,7 +344,7 @@ impl OlmMachine { /// See [`update_tracked_users`](#method.update_tracked_users) for more /// information. pub async fn tracked_users(&self) -> StoreResult> { - self.store().tracked_users().await + Ok(self.store().cache().await?.tracked_users()) } /// Enable or disable room key forwarding. @@ -504,7 +504,8 @@ impl OlmMachine { if reset || identity.is_empty().await { info!("Creating new cross signing identity"); - let account = &self.inner.store.cache().await?.account; + let cache = self.inner.store.cache().await?; + let account = cache.account().await?; let (id, request, signature_request) = account.bootstrap_cross_signing().await; *identity = id; @@ -533,13 +534,6 @@ impl OlmMachine { } } - /// Get the underlying Olm account of the machine. - #[cfg(any(test, feature = "testing"))] - #[allow(dead_code)] - pub(crate) async fn account(&self) -> Result { - Ok(self.inner.store.cache().await?.account.clone()) - } - /// Receive a successful keys upload response. /// /// # Arguments @@ -554,7 +548,7 @@ impl OlmMachine { .store .with_transaction(|mut tr| async { let account = tr.account().await?; - account.receive_keys_upload_response(response).await?; + account.receive_keys_upload_response(response)?; Ok((tr, ())) }) .await @@ -630,7 +624,7 @@ impl OlmMachine { /// /// [`receive_keys_upload_response`]: #method.receive_keys_upload_response async fn keys_for_upload(&self, account: &Account) -> Option { - let (device_keys, one_time_keys, fallback_keys) = account.keys_for_upload().await; + let (device_keys, one_time_keys, fallback_keys) = account.keys_for_upload(); if device_keys.is_none() && one_time_keys.is_empty() && fallback_keys.is_empty() { None @@ -662,7 +656,7 @@ impl OlmMachine { // Handle the decrypted event, e.g. fetch out Megolm sessions out of // the event. - self.handle_decrypted_to_device_event(&mut decrypted, changes).await?; + self.handle_decrypted_to_device_event(transaction.cache(), &mut decrypted, changes).await?; Ok(decrypted) } @@ -902,6 +896,7 @@ impl OlmMachine { )] async fn handle_decrypted_to_device_event( &self, + cache: &StoreCache, decrypted: &mut OlmDecryptionInfo, changes: &mut Changes, ) -> OlmResult<()> { @@ -924,7 +919,7 @@ impl OlmMachine { let name = self .inner .key_request_machine - .receive_secret_event(decrypted.result.sender_key, e, changes) + .receive_secret_event(cache, decrypted.result.sender_key, e, changes) .await?; // Set the secret name so other consumers of the event know @@ -1166,18 +1161,19 @@ impl OlmMachine { { let account = transaction.account().await?; - account - .update_key_counts( - sync_changes.one_time_keys_counts, - sync_changes.unused_fallback_keys, - ) - .await; + account.update_key_counts( + sync_changes.one_time_keys_counts, + sync_changes.unused_fallback_keys, + ) } if let Err(e) = self .inner .identity_manager - .receive_device_changes(sync_changes.changed_devices.changed.iter().map(|u| u.as_ref())) + .receive_device_changes( + transaction.cache(), + sync_changes.changed_devices.changed.iter().map(|u| u.as_ref()), + ) .await { error!(error = ?e, "Error marking a tracked user as changed"); @@ -1190,8 +1186,11 @@ impl OlmMachine { events.push(raw_event); } - let changed_sessions = - self.inner.key_request_machine.collect_incoming_key_requests().await?; + let changed_sessions = self + .inner + .key_request_machine + .collect_incoming_key_requests(transaction.cache()) + .await?; changes.sessions.extend(changed_sessions); changes.next_batch_token = sync_changes.next_batch_token; @@ -1881,10 +1880,13 @@ impl OlmMachine { pub async fn sign(&self, message: &str) -> Result { let mut signatures = Signatures::new(); - let account = &self.inner.store.cache().await?.account; - let key_id = account.signing_key_id(); - let signature = account.sign(message).await; - signatures.add_signature(self.user_id().to_owned(), key_id, signature); + { + let cache = self.inner.store.cache().await?; + let account = cache.account().await?; + let key_id = account.signing_key_id(); + let signature = account.sign(message); + signatures.add_signature(self.user_id().to_owned(), key_id, signature); + } match self.sign_with_master_key(message).await { Ok((key_id, signature)) => { @@ -2032,7 +2034,9 @@ impl OlmMachine { /// Testing purposes only. #[cfg(any(feature = "testing", test))] pub async fn uploaded_key_count(&self) -> Result { - Ok(self.inner.store.cache().await?.account.uploaded_key_count()) + let cache = self.inner.store.cache().await?; + let account = cache.account().await?; + Ok(account.uploaded_key_count()) } } @@ -2183,13 +2187,22 @@ pub(crate) mod tests { ) -> (OlmMachine, OneTimeKeys) { let machine = OlmMachine::new(user_id, bob_device_id()).await; - let account = machine.account().await.unwrap(); - account.generate_fallback_key_helper().await; - account.update_uploaded_key_count(0); - account.generate_one_time_keys().await; + let request = machine + .store() + .with_transaction(|mut tr| async { + let account = tr.account().await.unwrap(); + account.generate_fallback_key_helper(); + account.update_uploaded_key_count(0); + account.generate_one_time_keys(); + let request = machine + .keys_for_upload(account) + .await + .expect("Can't prepare initial key upload"); + Ok((tr, request)) + }) + .await + .unwrap(); - let request = - machine.keys_for_upload(&account).await.expect("Can't prepare initial key upload"); let response = keys_upload_response(); machine.receive_keys_upload_response(&response).await.unwrap(); @@ -2198,7 +2211,7 @@ pub(crate) mod tests { (machine, keys) } - async fn get_machine_after_query() -> (OlmMachine, OneTimeKeys) { + async fn get_machine_after_query_test_helper() -> (OlmMachine, OneTimeKeys) { let (machine, otk) = get_prepared_machine_test_helper(user_id(), false).await; let response = keys_query_response(); let req_id = TransactionId::new(); @@ -2287,7 +2300,9 @@ pub(crate) mod tests { #[async_test] async fn test_create_olm_machine() { let machine = OlmMachine::new(user_id(), alice_device_id()).await; - let account = machine.account().await.unwrap(); + + let cache = machine.store().cache().await.unwrap(); + let account = cache.account().await.unwrap(); assert!(!account.shared()); let own_device = machine @@ -2302,27 +2317,59 @@ pub(crate) mod tests { #[async_test] async fn test_generate_one_time_keys() { let machine = OlmMachine::new(user_id(), alice_device_id()).await; - let account = machine.account().await.unwrap(); - assert!(account.generate_one_time_keys().await.is_some()); + machine + .store() + .with_transaction(|mut tr| async { + let account = tr.account().await.unwrap(); + assert!(account.generate_one_time_keys().is_some()); + Ok((tr, ())) + }) + .await + .unwrap(); let mut response = keys_upload_response(); machine.receive_keys_upload_response(&response).await.unwrap(); - assert!(account.generate_one_time_keys().await.is_some()); + + machine + .store() + .with_transaction(|mut tr| async { + let account = tr.account().await.unwrap(); + assert!(account.generate_one_time_keys().is_some()); + Ok((tr, ())) + }) + .await + .unwrap(); response.one_time_key_counts.insert(DeviceKeyAlgorithm::SignedCurve25519, uint!(50)); + machine.receive_keys_upload_response(&response).await.unwrap(); - assert!(account.generate_one_time_keys().await.is_none()); + + machine + .store() + .with_transaction(|mut tr| async { + let account = tr.account().await.unwrap(); + assert!(account.generate_one_time_keys().is_none()); + + Ok((tr, ())) + }) + .await + .unwrap(); } #[async_test] async fn test_device_key_signing() { let machine = OlmMachine::new(user_id(), alice_device_id()).await; - let account = machine.account().await.unwrap(); - let device_keys = account.device_keys().await; - let identity_keys = account.identity_keys(); + let (device_keys, identity_keys) = { + let cache = machine.store().cache().await.unwrap(); + let account = cache.account().await.unwrap(); + let device_keys = account.device_keys(); + let identity_keys = account.identity_keys(); + (device_keys, identity_keys) + }; + let ed25519_key = identity_keys.ed25519; let ret = ed25519_key.verify_json( @@ -2351,11 +2398,11 @@ pub(crate) mod tests { .invalidated()); } - #[async_test] - async fn test_invalid_signature() { + #[test] + fn test_invalid_signature() { let account = Account::with_device_id(user_id(), alice_device_id()); - let device_keys = account.device_keys().await; + let device_keys = account.device_keys(); let key = Ed25519PublicKey::from_slice(&[0u8; 32]).unwrap(); @@ -2367,13 +2414,13 @@ pub(crate) mod tests { ret.unwrap_err(); } - #[async_test] - async fn test_one_time_key_signing() { - let account = Account::with_device_id(user_id(), alice_device_id()); + #[test] + fn test_one_time_key_signing() { + let mut account = Account::with_device_id(user_id(), alice_device_id()); account.update_uploaded_key_count(49); - account.generate_one_time_keys().await; + account.generate_one_time_keys(); - let mut one_time_keys = account.signed_one_time_keys().await; + let mut one_time_keys = account.signed_one_time_keys(); let ed25519_key = account.identity_keys().ed25519; let one_time_key: SignedKey = one_time_keys @@ -2395,6 +2442,7 @@ pub(crate) mod tests { #[async_test] async fn test_keys_for_upload() { let machine = OlmMachine::new(user_id(), alice_device_id()).await; + let key_counts = BTreeMap::from([(DeviceKeyAlgorithm::SignedCurve25519, 49u8.into())]); machine .receive_sync_changes(EncryptionSyncChanges { @@ -2407,11 +2455,15 @@ pub(crate) mod tests { .await .expect("We should be able to update our one-time key counts"); - let account = machine.account().await.unwrap(); - let ed25519_key = account.identity_keys().ed25519; + let (ed25519_key, mut request) = { + let cache = machine.store().cache().await.unwrap(); + let account = cache.account().await.unwrap(); + let ed25519_key = account.identity_keys().ed25519; - let mut request = - machine.keys_for_upload(&account).await.expect("Can't prepare initial key upload"); + let request = + machine.keys_for_upload(&account).await.expect("Can't prepare initial key upload"); + (ed25519_key, request) + }; let one_time_key: SignedKey = request .one_time_keys @@ -2437,16 +2489,27 @@ pub(crate) mod tests { ); ret.unwrap(); - let mut response = keys_upload_response(); - response.one_time_key_counts.insert( - DeviceKeyAlgorithm::SignedCurve25519, - (account.max_one_time_keys().await).try_into().unwrap(), - ); + let response = { + let cache = machine.store().cache().await.unwrap(); + let account = cache.account().await.unwrap(); + + let mut response = keys_upload_response(); + response.one_time_key_counts.insert( + DeviceKeyAlgorithm::SignedCurve25519, + account.max_one_time_keys().try_into().unwrap(), + ); + + response + }; machine.receive_keys_upload_response(&response).await.unwrap(); - let ret = machine.keys_for_upload(&account).await; - assert!(ret.is_none()); + { + let cache = machine.store().cache().await.unwrap(); + let account = cache.account().await.unwrap(); + let ret = machine.keys_for_upload(&account).await; + assert!(ret.is_none()); + } } #[async_test] @@ -2477,7 +2540,7 @@ pub(crate) mod tests { #[async_test] async fn test_missing_sessions_calculation() { - let (machine, _) = get_machine_after_query().await; + let (machine, _) = get_machine_after_query_test_helper().await; let alice = alice_id(); let alice_device = alice_device_id(); @@ -2523,7 +2586,7 @@ pub(crate) mod tests { let session = alice_machine .store() .get_sessions( - &bob_machine.account().await.unwrap().identity_keys().curve25519.to_base64(), + &bob_machine.store().static_account().identity_keys().curve25519.to_base64(), ) .await .unwrap() @@ -2533,7 +2596,7 @@ pub(crate) mod tests { } #[async_test] - async fn getting_most_recent_session() { + async fn test_getting_most_recent_session() { let (alice_machine, bob_machine, mut one_time_keys) = get_machine_pair(alice_id(), user_id(), false).await; let (device_key_id, one_time_key) = one_time_keys.pop_first().unwrap(); @@ -2658,12 +2721,12 @@ pub(crate) mod tests { } #[async_test] - async fn olm_encryption() { + async fn test_olm_encryption() { olm_encryption_test_helper(false).await; } #[async_test] - async fn olm_encryption_with_fallback_key() { + async fn test_olm_encryption_with_fallback_key() { olm_encryption_test_helper(true).await; } @@ -2755,7 +2818,7 @@ pub(crate) mod tests { async fn test_request_missing_secrets_cross_signed() { let (alice, bob) = get_machine_pair_with_session(alice_id(), bob_id(), false).await; - setup_cross_signing_for_machine(&alice, &bob).await; + setup_cross_signing_for_machine_test_helper(&alice, &bob).await; let should_query_secrets = alice.query_missing_secrets_from_other_sessions().await.unwrap(); @@ -3032,7 +3095,7 @@ pub(crate) mod tests { ); assert_shield!(encryption_info, Red, Red); - setup_cross_signing_for_machine(&alice, &bob).await; + setup_cross_signing_for_machine_test_helper(&alice, &bob).await; let bob_id_from_alice = alice.get_identity(bob.user_id(), None).await.unwrap(); assert_matches!(bob_id_from_alice, Some(UserIdentities::Other(_))); let alice_id_from_bob = bob.get_identity(alice.user_id(), None).await.unwrap(); @@ -3113,7 +3176,7 @@ pub(crate) mod tests { ); } - async fn setup_cross_signing_for_machine(alice: &OlmMachine, bob: &OlmMachine) { + async fn setup_cross_signing_for_machine_test_helper(alice: &OlmMachine, bob: &OlmMachine) { let (alice_upload_signing, _) = alice.bootstrap_cross_signing(false).await.expect("Expect Alice x-signing key request"); @@ -3457,7 +3520,7 @@ pub(crate) mod tests { } #[async_test] - async fn interactive_verification() { + async fn test_interactive_verification() { let (alice, bob) = get_machine_pair_with_setup_sessions_test_helper(alice_id(), user_id(), false).await; @@ -3853,7 +3916,7 @@ pub(crate) mod tests { let (alice, bob) = get_machine_pair_with_setup_sessions_test_helper(alice_id(), user_id(), false).await; - setup_cross_signing_for_machine(&alice, &bob).await; + setup_cross_signing_for_machine_test_helper(&alice, &bob).await; let second_alice = create_additional_machine(&alice).await; diff --git a/crates/matrix-sdk-crypto/src/olm/account.rs b/crates/matrix-sdk-crypto/src/olm/account.rs index 1c1c9e007ac..d13848603a0 100644 --- a/crates/matrix-sdk-crypto/src/olm/account.rs +++ b/crates/matrix-sdk-crypto/src/olm/account.rs @@ -15,11 +15,8 @@ use std::{ collections::{BTreeMap, HashMap}, fmt, - ops::Deref, - sync::{ - atomic::{AtomicBool, AtomicU64, Ordering}, - Arc, - }, + ops::{Deref, Not as _}, + sync::Arc, }; use ruma::{ @@ -314,19 +311,18 @@ impl StaticAccountData { /// /// An account is the central identity for encrypted communication between two /// devices. -#[derive(Clone)] pub struct Account { pub(crate) static_data: StaticAccountData, /// `vodozemac` account. - inner: Arc>, + inner: InnerAccount, /// Is this account ready to encrypt messages? (i.e. has it shared keys with /// a homeserver) - shared: Arc, + shared: bool, /// The number of signed one-time keys we have uploaded to the server. If /// this is None, no action will be taken. After a sync request the client /// needs to set this for us, depending on the count we will suggest the /// client to upload new keys. - uploaded_signed_key_count: Arc, + uploaded_signed_key_count: u64, } impl Deref for Account { @@ -374,6 +370,9 @@ impl fmt::Debug for Account { } } +pub type OneTimeKeys = BTreeMap>; +pub type FallbackKeys = OneTimeKeys; + impl Account { fn new_helper(mut account: InnerAccount, user_id: &UserId, device_id: &DeviceId) -> Self { let identity_keys = account.identity_keys(); @@ -398,9 +397,9 @@ impl Account { identity_keys: Arc::new(identity_keys), creation_local_time: MilliSecondsSinceUnixEpoch::now(), }, - inner: Arc::new(Mutex::new(account)), - shared: Arc::new(AtomicBool::new(false)), - uploaded_signed_key_count: Arc::new(AtomicU64::new(0)), + inner: account, + shared: false, + uploaded_signed_key_count: 0, } } @@ -431,47 +430,47 @@ impl Account { /// # Arguments /// /// * `new_count` - The new count that was reported by the server. - pub fn update_uploaded_key_count(&self, new_count: u64) { - self.uploaded_signed_key_count.store(new_count, Ordering::SeqCst); + pub fn update_uploaded_key_count(&mut self, new_count: u64) { + self.uploaded_signed_key_count = new_count; } /// Get the currently known uploaded key count. pub fn uploaded_key_count(&self) -> u64 { - self.uploaded_signed_key_count.load(Ordering::SeqCst) + self.uploaded_signed_key_count } /// Has the account been shared with the server. pub fn shared(&self) -> bool { - self.shared.load(Ordering::SeqCst) + self.shared } /// Mark the account as shared. /// /// Messages shouldn't be encrypted with the session before it has been /// shared. - pub fn mark_as_shared(&self) { - self.shared.store(true, Ordering::SeqCst); + pub fn mark_as_shared(&mut self) { + self.shared = true; } /// Get the one-time keys of the account. /// /// This can be empty, keys need to be generated first. - pub async fn one_time_keys(&self) -> HashMap { - self.inner.lock().await.one_time_keys() + pub fn one_time_keys(&self) -> HashMap { + self.inner.one_time_keys() } /// Generate count number of one-time keys. - pub async fn generate_one_time_keys_helper(&self, count: usize) -> OneTimeKeyGenerationResult { - self.inner.lock().await.generate_one_time_keys(count) + pub fn generate_one_time_keys_helper(&mut self, count: usize) -> OneTimeKeyGenerationResult { + self.inner.generate_one_time_keys(count) } /// Get the maximum number of one-time keys the account can hold. - pub async fn max_one_time_keys(&self) -> usize { - self.inner.lock().await.max_number_of_one_time_keys() + pub fn max_one_time_keys(&self) -> usize { + self.inner.max_number_of_one_time_keys() } - pub(crate) async fn update_key_counts( - &self, + pub(crate) fn update_key_counts( + &mut self, one_time_key_counts: &BTreeMap, unused_fallback_keys: Option<&[DeviceKeyAlgorithm]>, ) { @@ -490,13 +489,13 @@ impl Account { } self.update_uploaded_key_count(count); - self.generate_one_time_keys().await; + self.generate_one_time_keys(); } if let Some(unused) = unused_fallback_keys { if !unused.contains(&DeviceKeyAlgorithm::SignedCurve25519) { // Generate a new fallback key if we don't have one. - self.generate_fallback_key_helper().await; + self.generate_fallback_key_helper(); } } } @@ -510,13 +509,13 @@ impl Account { /// Generally `Some` means that keys should be uploaded, while `None` means /// that keys should not be uploaded. #[instrument(skip_all)] - pub async fn generate_one_time_keys(&self) -> Option { + pub fn generate_one_time_keys(&mut self) -> Option { // Only generate one-time keys if there aren't any, otherwise the caller // might have failed to upload them the last time this method was // called. - if self.one_time_keys().await.is_empty() { + if self.one_time_keys().is_empty() { let count = self.uploaded_key_count(); - let max_keys = self.max_one_time_keys().await; + let max_keys = self.max_one_time_keys(); if count >= max_keys as u64 { return None; @@ -525,7 +524,7 @@ impl Account { let key_count = (max_keys as u64) - count; let key_count: usize = key_count.try_into().unwrap_or(max_keys); - let result = self.generate_one_time_keys_helper(key_count).await; + let result = self.generate_one_time_keys_helper(key_count); debug!( count = key_count, @@ -540,11 +539,9 @@ impl Account { } } - pub(crate) async fn generate_fallback_key_helper(&self) { - let mut account = self.inner.lock().await; - - if account.fallback_key().is_empty() { - let removed_fallback_key = account.generate_fallback_key(); + pub(crate) fn generate_fallback_key_helper(&mut self) { + if self.inner.fallback_key().is_empty() { + let removed_fallback_key = self.inner.generate_fallback_key(); debug!( ?removed_fallback_key, @@ -553,8 +550,8 @@ impl Account { } } - async fn fallback_key(&self) -> HashMap { - self.inner.lock().await.fallback_key() + fn fallback_key(&self) -> HashMap { + self.inner.fallback_key() } /// Get a tuple of device, one-time, and fallback keys that need to be @@ -562,36 +559,30 @@ impl Account { /// /// If no keys need to be uploaded the `DeviceKeys` will be `None` and the /// one-time and fallback keys maps will be empty. - pub async fn keys_for_upload( - &self, - ) -> ( - Option, - BTreeMap>, - BTreeMap>, - ) { - let device_keys = if !self.shared() { Some(self.device_keys().await) } else { None }; + pub fn keys_for_upload(&self) -> (Option, OneTimeKeys, FallbackKeys) { + let device_keys = self.shared().not().then(|| self.device_keys()); - let one_time_keys = self.signed_one_time_keys().await; - let fallback_keys = self.signed_fallback_keys().await; + let one_time_keys = self.signed_one_time_keys(); + let fallback_keys = self.signed_fallback_keys(); (device_keys, one_time_keys, fallback_keys) } /// Mark the current set of one-time keys as being published. - pub async fn mark_keys_as_published(&self) { - self.inner.lock().await.mark_keys_as_published(); + pub fn mark_keys_as_published(&mut self) { + self.inner.mark_keys_as_published(); } /// Sign the given string using the accounts signing key. /// /// Returns the signature as a base64 encoded string. - pub async fn sign(&self, string: &str) -> Ed25519Signature { - self.inner.lock().await.sign(string) + pub fn sign(&self, string: &str) -> Ed25519Signature { + self.inner.sign(string) } /// Get a serializeable version of the `Account` so it can be persisted. - pub async fn pickle(&self) -> PickledAccount { - let pickle = self.inner.lock().await.pickle(); + pub fn pickle(&self) -> PickledAccount { + let pickle = self.inner.pickle(); PickledAccount { user_id: self.user_id().to_owned(), @@ -603,11 +594,11 @@ impl Account { } } - pub(crate) async fn dehydrate(&self, pickle_key: &[u8; 32]) -> Raw { - let device_pickle = - self.inner.lock().await.to_libolm_pickle(pickle_key).expect( - "We should be able to convert a freshly created Account into a libolm pickle", - ); + pub(crate) fn dehydrate(&self, pickle_key: &[u8; 32]) -> Raw { + let device_pickle = self + .inner + .to_libolm_pickle(pickle_key) + .expect("We should be able to convert a freshly created Account into a libolm pickle"); let data = DehydratedDeviceData::V1(DehydratedDeviceV1::new(device_pickle)); Raw::from_json(to_raw_value(&data).expect("Coulnd't our dehydrated device data")) @@ -653,15 +644,15 @@ impl Account { identity_keys: Arc::new(identity_keys), creation_local_time: pickle.creation_local_time, }, - inner: Arc::new(Mutex::new(account)), - shared: Arc::new(AtomicBool::from(pickle.shared)), - uploaded_signed_key_count: Arc::new(AtomicU64::new(pickle.uploaded_signed_key_count)), + inner: account, + shared: pickle.shared, + uploaded_signed_key_count: pickle.uploaded_signed_key_count, }) } /// Sign the device keys of the account and return them so they can be /// uploaded. - pub async fn device_keys(&self) -> DeviceKeys { + pub fn device_keys(&self) -> DeviceKeys { let mut device_keys = self.unsigned_device_keys(); // Create a copy of the device keys containing only fields that will @@ -670,7 +661,6 @@ impl Account { serde_json::to_value(&device_keys).expect("device key is always safe to serialize"); let signature = self .sign_json(json_device_keys) - .await .expect("Newly created device keys can always be signed"); device_keys.signatures.add_signature( @@ -690,11 +680,11 @@ impl Account { } /// Sign the given CrossSigning Key in place - pub async fn sign_cross_signing_key( + pub fn sign_cross_signing_key( &self, cross_signing_key: &mut CrossSigningKey, ) -> Result<(), SignatureError> { - let signature = self.sign_json(serde_json::to_value(&cross_signing_key)?).await?; + let signature = self.sign_json(serde_json::to_value(&cross_signing_key)?)?; cross_signing_key.signatures.add_signature( self.user_id().to_owned(), @@ -706,7 +696,7 @@ impl Account { } /// Sign the given Master Key - pub async fn sign_master_key( + pub fn sign_master_key( &self, master_key: MasterPubkey, ) -> Result { @@ -715,7 +705,7 @@ impl Account { let mut cross_signing_key: CrossSigningKey = master_key.as_ref().clone(); cross_signing_key.signatures.clear(); - self.sign_cross_signing_key(&mut cross_signing_key).await?; + self.sign_cross_signing_key(&mut cross_signing_key)?; let mut user_signed_keys = SignedKeys::new(); user_signed_keys.add_cross_signing_keys(public_key, cross_signing_key.to_raw()); @@ -731,41 +721,41 @@ impl Account { /// /// * `json` - The value that should be converted into a canonical JSON /// string. - pub async fn sign_json(&self, json: Value) -> Result { - self.inner.lock().await.sign_json(json) + pub fn sign_json(&self, json: Value) -> Result { + self.inner.sign_json(json) } /// Generate, sign and prepare one-time keys to be uploaded. /// /// If no one-time keys need to be uploaded returns an empty BTreeMap. - pub async fn signed_one_time_keys( + pub fn signed_one_time_keys( &self, ) -> BTreeMap> { - let one_time_keys = self.one_time_keys().await; + let one_time_keys = self.one_time_keys(); if one_time_keys.is_empty() { BTreeMap::new() } else { - self.signed_keys(one_time_keys, false).await + self.signed_keys(one_time_keys, false) } } /// Sign and prepare fallback keys to be uploaded. /// /// If no fallback keys need to be uploaded returns an empty BTreeMap. - pub async fn signed_fallback_keys( + pub fn signed_fallback_keys( &self, ) -> BTreeMap> { - let fallback_key = self.fallback_key().await; + let fallback_key = self.fallback_key(); if fallback_key.is_empty() { BTreeMap::new() } else { - self.signed_keys(fallback_key, true).await + self.signed_keys(fallback_key, true) } } - async fn signed_keys( + fn signed_keys( &self, keys: HashMap, fallback: bool, @@ -773,7 +763,7 @@ impl Account { let mut keys_map = BTreeMap::new(); for (key_id, key) in keys { - let signed_key = self.sign_key(key, fallback).await; + let signed_key = self.sign_key(key, fallback); keys_map.insert( DeviceKeyId::from_parts( @@ -787,7 +777,7 @@ impl Account { keys_map } - async fn sign_key(&self, key: Curve25519PublicKey, fallback: bool) -> SignedKey { + fn sign_key(&self, key: Curve25519PublicKey, fallback: bool) -> SignedKey { let mut key = if fallback { SignedKey::new_fallback(key.to_owned()) } else { @@ -796,7 +786,6 @@ impl Account { let signature = self .sign_json(serde_json::to_value(&key).expect("Can't serialize a signed key")) - .await .expect("Newly created one-time keys can always be signed"); key.signatures_mut().add_signature( @@ -822,15 +811,14 @@ impl Account { /// created and shared with us. /// /// * `fallback_used` - Was the one-time key a fallback key. - pub async fn create_outbound_session_helper( + pub fn create_outbound_session_helper( &self, config: SessionConfig, identity_key: Curve25519PublicKey, one_time_key: Curve25519PublicKey, fallback_used: bool, ) -> Session { - let session = - self.inner.lock().await.create_outbound_session(config, identity_key, one_time_key); + let session = self.inner.create_outbound_session(config, identity_key, one_time_key); let now = SecondsSinceUnixEpoch::now(); let session_id = session.session_id(); @@ -859,7 +847,8 @@ impl Account { /// /// * `key_map` - A map from the algorithm and device ID to the one-time key /// that the other account created and shared with us. - pub async fn create_outbound_session( + #[allow(clippy::result_large_err)] + pub fn create_outbound_session( &self, device: &ReadOnlyDevice, key_map: &BTreeMap>, @@ -907,9 +896,7 @@ impl Account { let one_time_key = one_time_key.key(); let config = device.olm_session_config(); - Ok(self - .create_outbound_session_helper(config, identity_key, one_time_key, is_fallback) - .await) + Ok(self.create_outbound_session_helper(config, identity_key, one_time_key, is_fallback)) } /// Create a new session with another account given a pre-key Olm message. @@ -930,20 +917,18 @@ impl Account { session, ) )] - pub async fn create_inbound_session( - &self, + pub fn create_inbound_session( + &mut self, their_identity_key: Curve25519PublicKey, message: &PreKeyMessage, ) -> Result { debug!("Creating a new Olm session from a pre-key message"); let result = - self.inner.lock().await.create_inbound_session(their_identity_key, message).map_err( - |e| { - warn!("Failed to create a new Olm session from a pre-key message: {e:?}"); - e - }, - )?; + self.inner.create_inbound_session(their_identity_key, message).map_err(|e| { + warn!("Failed to create a new Olm session from a pre-key message: {e:?}"); + e + })?; let now = SecondsSinceUnixEpoch::now(); let session_id = result.session.session_id(); @@ -972,17 +957,17 @@ impl Account { #[cfg(any(test, feature = "testing"))] #[allow(dead_code)] /// Testing only helper to create a session for the given Account - pub async fn create_session_for(&self, other: &Account) -> (Session, Session) { + pub async fn create_session_for(&mut self, other: &mut Account) -> (Session, Session) { use ruma::events::dummy::ToDeviceDummyEventContent; - other.generate_one_time_keys_helper(1).await; - let one_time = other.signed_one_time_keys().await; + other.generate_one_time_keys_helper(1); + let one_time = other.signed_one_time_keys(); - let device = ReadOnlyDevice::from_account(other).await; + let device = ReadOnlyDevice::from_account(other); - let mut our_session = self.create_outbound_session(&device, &one_time).await.unwrap(); + let mut our_session = self.create_outbound_session(&device, &one_time).unwrap(); - other.mark_keys_as_published().await; + other.mark_keys_as_published(); let message = our_session .encrypt( @@ -1016,17 +1001,15 @@ impl Account { panic!("Wrong Olm message type"); }; - let our_device = ReadOnlyDevice::from_account(self).await; - let other_session = other - .create_inbound_session(our_device.curve25519_key().unwrap(), &prekey) - .await - .unwrap(); + let our_device = ReadOnlyDevice::from_account(self); + let other_session = + other.create_inbound_session(our_device.curve25519_key().unwrap(), &prekey).unwrap(); (our_session, other_session.session) } async fn decrypt_olm_helper( - &self, + &mut self, store: &Store, sender: &UserId, sender_key: Curve25519PublicKey, @@ -1052,7 +1035,7 @@ impl Account { #[cfg(feature = "experimental-algorithms")] async fn decrypt_olm_v2( - &self, + &mut self, store: &Store, sender: &UserId, content: &OlmV2Curve25519AesSha2Content, @@ -1062,7 +1045,7 @@ impl Account { #[instrument(skip_all, fields(sender, sender_key = %content.sender_key))] async fn decrypt_olm_v1( - &self, + &mut self, store: &Store, sender: &UserId, content: &OlmV1Curve25519AesSha2Content, @@ -1084,7 +1067,7 @@ impl Account { #[instrument(skip_all, fields(algorithm = ?event.content.algorithm()))] pub(crate) async fn decrypt_to_device_event( - &self, + &mut self, store: &Store, event: &EncryptedToDeviceEvent, ) -> OlmResult { @@ -1110,8 +1093,8 @@ impl Account { } /// Handles a response to a /keys/upload request. - pub async fn receive_keys_upload_response( - &self, + pub fn receive_keys_upload_response( + &mut self, response: &upload_keys::v3::Response, ) -> OlmResult<()> { if !self.shared() { @@ -1122,8 +1105,8 @@ impl Account { debug!("Marking one-time keys as published"); // First mark the current keys as published, as updating the key counts might // generate some new keys if we're still below the limit. - self.mark_keys_as_published().await; - self.update_key_counts(&response.one_time_key_counts, None).await; + self.mark_keys_as_published(); + self.update_key_counts(&response.one_time_key_counts, None); Ok(()) } @@ -1183,7 +1166,7 @@ impl Account { /// Decrypt an Olm message, creating a new Olm session if possible. #[instrument(skip(self, message))] async fn decrypt_olm_message( - &self, + &mut self, store: &Store, sender: &UserId, sender_key: Curve25519PublicKey, @@ -1221,7 +1204,7 @@ impl Account { OlmMessage::PreKey(m) => { // Create the new session. - let result = match self.create_inbound_session(sender_key, m).await { + let result = match self.create_inbound_session(sender_key, m) { Ok(r) => r, Err(_) => { return Err(OlmError::SessionWedged(sender.to_owned(), sender_key)); @@ -1356,6 +1339,16 @@ impl Account { }) } } + + /// Internal use only. + /// + /// Cloning should only be done for testing purposes or when we are certain + /// that we don't want the inner state to be shared. + #[doc(hidden)] + pub fn deep_clone(&self) -> Self { + // `vodozemac::Account` isn't really clonable, but... Don't tell anyone. + Self::from_pickle(self.pickle()).unwrap() + } } impl PartialEq for Account { @@ -1394,14 +1387,14 @@ mod tests { device_id!("DEVICEID") } - #[async_test] - async fn one_time_key_creation() -> Result<()> { - let account = Account::with_device_id(user_id(), device_id()); + #[test] + fn test_one_time_key_creation() -> Result<()> { + let mut account = Account::with_device_id(user_id(), device_id()); - let (_, one_time_keys, _) = account.keys_for_upload().await; + let (_, one_time_keys, _) = account.keys_for_upload(); assert!(!one_time_keys.is_empty()); - let (_, second_one_time_keys, _) = account.keys_for_upload().await; + let (_, second_one_time_keys, _) = account.keys_for_upload(); assert!(!second_one_time_keys.is_empty()); let device_key_ids: BTreeSet<&DeviceKeyId> = @@ -1411,17 +1404,17 @@ mod tests { assert_eq!(device_key_ids, second_device_key_ids); - account.mark_keys_as_published().await; + account.mark_keys_as_published(); account.update_uploaded_key_count(50); - account.generate_one_time_keys().await; + account.generate_one_time_keys(); - let (_, third_one_time_keys, _) = account.keys_for_upload().await; + let (_, third_one_time_keys, _) = account.keys_for_upload(); assert!(third_one_time_keys.is_empty()); account.update_uploaded_key_count(0); - account.generate_one_time_keys().await; + account.generate_one_time_keys(); - let (_, fourth_one_time_keys, _) = account.keys_for_upload().await; + let (_, fourth_one_time_keys, _) = account.keys_for_upload(); assert!(!fourth_one_time_keys.is_empty()); let fourth_device_key_ids: BTreeSet<&DeviceKeyId> = @@ -1431,11 +1424,11 @@ mod tests { Ok(()) } - #[async_test] - async fn fallback_key_creation() -> Result<()> { - let account = Account::with_device_id(user_id(), device_id()); + #[test] + fn test_fallback_key_creation() -> Result<()> { + let mut account = Account::with_device_id(user_id(), device_id()); - let (_, _, fallback_keys) = account.keys_for_upload().await; + let (_, _, fallback_keys) = account.keys_for_upload(); // We don't create fallback keys since we don't know if the server // supports them, we need to receive a sync response to decide if we're @@ -1446,36 +1439,36 @@ mod tests { // A `None` here means that the server doesn't support fallback keys, no // fallback key gets uploaded. - account.update_key_counts(&one_time_keys, None).await; - let (_, _, fallback_keys) = account.keys_for_upload().await; + account.update_key_counts(&one_time_keys, None); + let (_, _, fallback_keys) = account.keys_for_upload(); assert!(fallback_keys.is_empty()); // The empty array means that the server supports fallback keys but // there isn't a unused fallback key on the server. This time we upload // a fallback key. let unused_fallback_keys = &[]; - account.update_key_counts(&one_time_keys, Some(unused_fallback_keys.as_ref())).await; - let (_, _, fallback_keys) = account.keys_for_upload().await; + account.update_key_counts(&one_time_keys, Some(unused_fallback_keys.as_ref())); + let (_, _, fallback_keys) = account.keys_for_upload(); assert!(!fallback_keys.is_empty()); - account.mark_keys_as_published().await; + account.mark_keys_as_published(); // There's an unused fallback key on the server, nothing to do here. let unused_fallback_keys = &[DeviceKeyAlgorithm::SignedCurve25519]; - account.update_key_counts(&one_time_keys, Some(unused_fallback_keys.as_ref())).await; - let (_, _, fallback_keys) = account.keys_for_upload().await; + account.update_key_counts(&one_time_keys, Some(unused_fallback_keys.as_ref())); + let (_, _, fallback_keys) = account.keys_for_upload(); assert!(fallback_keys.is_empty()); Ok(()) } - #[async_test] - async fn fallback_key_signing() -> Result<()> { + #[test] + fn test_fallback_key_signing() -> Result<()> { let key = vodozemac::Curve25519PublicKey::from_base64( "7PUPP6Ijt5R8qLwK2c8uK5hqCNF9tOzWYgGaAay5JBs", )?; let account = Account::with_device_id(user_id(), device_id()); - let key = account.sign_key(key, true).await; + let key = account.sign_key(key, true); let canonical_key = key.to_canonical_json()?; @@ -1488,14 +1481,14 @@ mod tests { .has_signed_raw(key.signatures(), &canonical_key) .expect("Couldn't verify signature"); - let device = ReadOnlyDevice::from_account(&account).await; + let device = ReadOnlyDevice::from_account(&account); device.verify_one_time_key(&key).expect("The device can verify its own signature"); Ok(()) } - #[async_test] - async fn test_account_and_device_creation_timestamp() -> Result<()> { + #[test] + fn test_account_and_device_creation_timestamp() -> Result<()> { let now = MilliSecondsSinceUnixEpoch::now(); let account = Account::with_device_id(user_id(), device_id()); let then = MilliSecondsSinceUnixEpoch::now(); @@ -1503,7 +1496,7 @@ mod tests { assert!(account.creation_local_time() >= now); assert!(account.creation_local_time() <= then); - let device = ReadOnlyDevice::from_account(&account).await; + let device = ReadOnlyDevice::from_account(&account); assert_eq!(account.creation_local_time(), device.first_time_seen_ts()); Ok(()) diff --git a/crates/matrix-sdk-crypto/src/olm/mod.rs b/crates/matrix-sdk-crypto/src/olm/mod.rs index 45e55e2aafd..83630cd4df5 100644 --- a/crates/matrix-sdk-crypto/src/olm/mod.rs +++ b/crates/matrix-sdk-crypto/src/olm/mod.rs @@ -79,28 +79,26 @@ pub(crate) mod tests { device_id!("BOBDEVICE") } - pub(crate) async fn get_account_and_session() -> (Account, Session) { + pub(crate) fn get_account_and_session_test_helper() -> (Account, Session) { let alice = Account::with_device_id(alice_id(), alice_device_id()); - let bob = Account::with_device_id(bob_id(), bob_device_id()); + let mut bob = Account::with_device_id(bob_id(), bob_device_id()); - bob.generate_one_time_keys_helper(1).await; - let one_time_key = *bob.one_time_keys().await.values().next().unwrap(); + bob.generate_one_time_keys_helper(1); + let one_time_key = *bob.one_time_keys().values().next().unwrap(); let sender_key = bob.identity_keys().curve25519; - let session = alice - .create_outbound_session_helper( - SessionConfig::default(), - sender_key, - one_time_key, - false, - ) - .await; + let session = alice.create_outbound_session_helper( + SessionConfig::default(), + sender_key, + one_time_key, + false, + ); (alice, session) } #[test] - fn account_creation() { - let account = Account::with_device_id(alice_id(), alice_device_id()); + fn test_account_creation() { + let mut account = Account::with_device_id(alice_id(), alice_device_id()); assert!(!account.shared()); @@ -108,45 +106,43 @@ pub(crate) mod tests { assert!(account.shared()); } - #[async_test] - async fn one_time_keys_creation() { - let account = Account::with_device_id(alice_id(), alice_device_id()); - let one_time_keys = account.one_time_keys().await; + #[test] + fn test_one_time_keys_creation() { + let mut account = Account::with_device_id(alice_id(), alice_device_id()); + let one_time_keys = account.one_time_keys(); assert!(!one_time_keys.is_empty()); - assert_ne!(account.max_one_time_keys().await, 0); + assert_ne!(account.max_one_time_keys(), 0); - account.generate_one_time_keys_helper(10).await; - let one_time_keys = account.one_time_keys().await; + account.generate_one_time_keys_helper(10); + let one_time_keys = account.one_time_keys(); assert_ne!(one_time_keys.values().len(), 0); assert_ne!(one_time_keys.keys().len(), 0); assert_ne!(one_time_keys.iter().len(), 0); - account.mark_keys_as_published().await; - let one_time_keys = account.one_time_keys().await; + account.mark_keys_as_published(); + let one_time_keys = account.one_time_keys(); assert!(one_time_keys.is_empty()); } #[async_test] - async fn session_creation() { - let alice = Account::with_device_id(alice_id(), alice_device_id()); + async fn test_session_creation() { + let mut alice = Account::with_device_id(alice_id(), alice_device_id()); let bob = Account::with_device_id(bob_id(), bob_device_id()); let alice_keys = alice.identity_keys(); - alice.generate_one_time_keys_helper(1).await; - let one_time_keys = alice.one_time_keys().await; - alice.mark_keys_as_published().await; + alice.generate_one_time_keys_helper(1); + let one_time_keys = alice.one_time_keys(); + alice.mark_keys_as_published(); let one_time_key = *one_time_keys.values().next().unwrap(); - let mut bob_session = bob - .create_outbound_session_helper( - SessionConfig::default(), - alice_keys.curve25519, - one_time_key, - false, - ) - .await; + let mut bob_session = bob.create_outbound_session_helper( + SessionConfig::default(), + alice_keys.curve25519, + one_time_key, + false, + ); let plaintext = "Hello world"; @@ -158,8 +154,7 @@ pub(crate) mod tests { }; let bob_keys = bob.identity_keys(); - let result = - alice.create_inbound_session(bob_keys.curve25519, &prekey_message).await.unwrap(); + let result = alice.create_inbound_session(bob_keys.curve25519, &prekey_message).unwrap(); assert_eq!(bob_session.session_id(), result.session.session_id()); @@ -167,7 +162,7 @@ pub(crate) mod tests { } #[async_test] - async fn group_session_creation() { + async fn test_group_session_creation() { let alice = Account::with_device_id(alice_id(), alice_device_id()); let room_id = room_id!("!test:localhost"); diff --git a/crates/matrix-sdk-crypto/src/olm/signing/mod.rs b/crates/matrix-sdk-crypto/src/olm/signing/mod.rs index 5f1ac9a3905..b36130cbd53 100644 --- a/crates/matrix-sdk-crypto/src/olm/signing/mod.rs +++ b/crates/matrix-sdk-crypto/src/olm/signing/mod.rs @@ -513,7 +513,6 @@ impl PrivateCrossSigningIdentity { account .sign_cross_signing_key(&mut public_key) - .await .expect("Can't sign our freshly created master key with our account"); master.public_key = public_key @@ -672,7 +671,7 @@ mod tests { } #[test] - fn signature_verification() { + fn test_signature_verification() { let signing = Signing::new(); let user_id = user_id(); let key_id = DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, "DEVICEID".into()); @@ -696,7 +695,7 @@ mod tests { } #[test] - fn pickling_signing() { + fn test_pickling_signing() { let signing = Signing::new(); let pickled = signing.pickle(); @@ -706,7 +705,7 @@ mod tests { } #[async_test] - async fn private_identity_creation() { + async fn test_private_identity_creation() { let identity = PrivateCrossSigningIdentity::new(user_id().to_owned()).await; let master_key = identity.master_key.lock().await; @@ -724,7 +723,7 @@ mod tests { } #[async_test] - async fn identity_pickling() { + async fn test_identity_pickling() { let identity = PrivateCrossSigningIdentity::new(user_id().to_owned()).await; let pickled = identity.pickle().await; @@ -744,7 +743,7 @@ mod tests { } #[async_test] - async fn private_identity_signed_by_account() { + async fn test_private_identity_signed_by_account() { let account = Account::with_device_id(user_id(), device_id!("DEVICEID")); let (identity, _, _) = PrivateCrossSigningIdentity::with_account(&account).await; let master = identity.master_key.lock().await; @@ -767,11 +766,11 @@ mod tests { } #[async_test] - async fn sign_device() { + async fn test_sign_device() { let account = Account::with_device_id(user_id(), device_id!("DEVICEID")); let (identity, _, _) = PrivateCrossSigningIdentity::with_account(&account).await; - let mut device = ReadOnlyDevice::from_account(&account).await; + let mut device = ReadOnlyDevice::from_account(&account); let self_signing = identity.self_signing_key.lock().await; let self_signing = self_signing.as_ref().unwrap(); @@ -784,7 +783,7 @@ mod tests { } #[async_test] - async fn sign_user_identity() { + async fn test_sign_user_identity() { let account = Account::with_device_id(user_id(), device_id!("DEVICEID")); let (identity, _, _) = PrivateCrossSigningIdentity::with_account(&account).await; diff --git a/crates/matrix-sdk-crypto/src/session_manager/group_sessions.rs b/crates/matrix-sdk-crypto/src/session_manager/group_sessions.rs index 9aef9287323..d59764b7ffa 100644 --- a/crates/matrix-sdk-crypto/src/session_manager/group_sessions.rs +++ b/crates/matrix-sdk-crypto/src/session_manager/group_sessions.rs @@ -969,7 +969,7 @@ mod tests { .expect("Can't parse the keys claim response") } - async fn machine_with_user(user_id: &UserId, device_id: &DeviceId) -> OlmMachine { + async fn machine_with_user_test_helper(user_id: &UserId, device_id: &DeviceId) -> OlmMachine { let keys_query = keys_query_response(); let keys_claim = keys_claim_response(); let txn_id = TransactionId::new(); @@ -985,10 +985,10 @@ mod tests { } async fn machine() -> OlmMachine { - machine_with_user(alice_id(), alice_device_id()).await + machine_with_user_test_helper(alice_id(), alice_device_id()).await } - async fn machine_with_shared_room_key() -> OlmMachine { + async fn machine_with_shared_room_key_test_helper() -> OlmMachine { let machine = machine().await; let room_id = room_id!("!test:localhost"); let keys_claim = keys_claim_response(); @@ -1119,7 +1119,7 @@ mod tests { #[async_test] async fn ratcheted_sharing() { - let machine = machine_with_shared_room_key().await; + let machine = machine_with_shared_room_key_test_helper().await; let room_id = room_id!("!test:localhost"); let late_joiner = user_id!("@bob:localhost"); @@ -1147,7 +1147,7 @@ mod tests { #[async_test] async fn changing_encryption_settings() { - let machine = machine_with_shared_room_key().await; + let machine = machine_with_shared_room_key_test_helper().await; let room_id = room_id!("!test:localhost"); let keys_claim = keys_claim_response(); @@ -1201,7 +1201,7 @@ mod tests { let device_id = device_id!("TESTDEVICE"); let room_id = room_id!("!test:localhost"); - let machine = machine_with_user(user_id, device_id).await; + let machine = machine_with_user_test_helper(user_id, device_id).await; let (outbound, _) = machine .inner @@ -1343,7 +1343,7 @@ mod tests { } #[async_test] - async fn no_olm_withheld_only_sent_once() { + async fn test_no_olm_withheld_only_sent_once() { let keys_query = keys_query_response(); let txn_id = TransactionId::new(); diff --git a/crates/matrix-sdk-crypto/src/session_manager/sessions.rs b/crates/matrix-sdk-crypto/src/session_manager/sessions.rs index 413961995d3..626c1c0471e 100644 --- a/crates/matrix-sdk-crypto/src/session_manager/sessions.rs +++ b/crates/matrix-sdk-crypto/src/session_manager/sessions.rs @@ -390,7 +390,7 @@ impl SessionManager { }; let account = store_transaction.account().await?; - let session = match account.create_outbound_session(&device, key_map).await { + let session = match account.create_outbound_session(&device, key_map) { Ok(s) => s, Err(e) => { warn!( @@ -438,7 +438,8 @@ impl SessionManager { } } - match self.key_request_machine.collect_incoming_key_requests().await { + let store_cache = self.store.cache().await?; + match self.key_request_machine.collect_incoming_key_requests(&store_cache).await { Ok(sessions) => { let changes = Changes { sessions, ..Default::default() }; self.store.save_changes(changes).await? @@ -484,7 +485,7 @@ mod tests { identities::{IdentityManager, ReadOnlyDevice}, olm::{Account, PrivateCrossSigningIdentity}, session_manager::GroupSessionCache, - store::{CryptoStoreWrapper, MemoryStore, Store}, + store::{CryptoStoreWrapper, MemoryStore, PendingChanges, Store}, verification::VerificationMachine, }; @@ -531,7 +532,6 @@ mod tests { let user_id = user_id(); let device_id = device_id(); - let users_for_key_claim = Arc::new(StdRwLock::new(BTreeMap::new())); let account = Account::with_device_id(user_id, device_id); let store = Arc::new(CryptoStoreWrapper::new(user_id, MemoryStore::new())); let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(user_id))); @@ -541,15 +541,12 @@ mod tests { store.clone(), ); - let store = Store::new(account.clone(), identity, store, verification); - { - // Perform a dummy transaction to sync in-memory cache with the db. - let tr = store.transaction().await.unwrap(); - tr.commit().await.unwrap(); - } + let store = Store::new(account.static_data().clone(), identity, store, verification); + store.save_pending_changes(PendingChanges { account: Some(account) }).await.unwrap(); let session_cache = GroupSessionCache::new(store.clone()); + let users_for_key_claim = Arc::new(StdRwLock::new(BTreeMap::new())); let key_request = GossipMachine::new(store.clone(), session_cache, users_for_key_claim.clone()); @@ -557,11 +554,11 @@ mod tests { } #[async_test] - async fn session_creation() { + async fn test_session_creation() { let manager = session_manager_test_helper().await; - let bob = bob_account(); + let mut bob = bob_account(); - let bob_device = ReadOnlyDevice::from_account(&bob).await; + let bob_device = ReadOnlyDevice::from_account(&bob); manager.store.save_devices(&[bob_device]).await.unwrap(); @@ -570,10 +567,10 @@ mod tests { assert!(request.one_time_keys.contains_key(bob.user_id())); - bob.generate_one_time_keys_helper(1).await; - let one_time = bob.signed_one_time_keys().await; + bob.generate_one_time_keys_helper(1); + let one_time = bob.signed_one_time_keys(); assert!(!one_time.is_empty()); - bob.mark_keys_as_published().await; + bob.mark_keys_as_published(); let mut one_time_keys = BTreeMap::new(); one_time_keys @@ -601,8 +598,11 @@ mod tests { // now bob turns up, and we start tracking his devices... let bob = bob_account(); - let bob_device = ReadOnlyDevice::from_account(&bob).await; - manager.store.update_tracked_users(iter::once(bob.user_id())).await.unwrap(); + let bob_device = ReadOnlyDevice::from_account(&bob); + { + let cache = manager.store.cache().await.unwrap(); + cache.update_tracked_users(&manager.store, iter::once(bob.user_id())).await.unwrap(); + } // ... and start off an attempt to get the missing sessions. This should block // for now. @@ -652,12 +652,19 @@ mod tests { use ruma::SecondsSinceUnixEpoch; let manager = session_manager_test_helper().await; - let bob = bob_account(); - - let manager_account = &manager.store.cache().await.unwrap().account; - let (_, mut session) = bob.create_session_for(manager_account).await; + let mut bob = bob_account(); + + let (_, mut session) = manager + .store + .with_transaction(|mut tr| async { + let manager_account = tr.account().await.unwrap(); + let res = bob.create_session_for(manager_account).await; + Ok((tr, res)) + }) + .await + .unwrap(); - let bob_device = ReadOnlyDevice::from_account(&bob).await; + let bob_device = ReadOnlyDevice::from_account(&bob); let time = SystemTime::now() - Duration::from_secs(3601); session.creation_time = SecondsSinceUnixEpoch::from_system_time(time).unwrap(); @@ -679,10 +686,10 @@ mod tests { assert!(request.one_time_keys.contains_key(bob.user_id())); - bob.generate_one_time_keys_helper(1).await; - let one_time = bob.signed_one_time_keys().await; + bob.generate_one_time_keys_helper(1); + let one_time = bob.signed_one_time_keys(); assert!(!one_time.is_empty()); - bob.mark_keys_as_published().await; + bob.mark_keys_as_published(); let mut one_time_keys = BTreeMap::new(); one_time_keys @@ -705,7 +712,7 @@ mod tests { async fn failure_handling() { let alice = user_id!("@alice:example.org"); let alice_account = Account::with_device_id(alice, "DEVICEID".into()); - let alice_device = ReadOnlyDevice::from_account(&alice_account).await; + let alice_device = ReadOnlyDevice::from_account(&alice_account); let manager = session_manager_test_helper().await; @@ -747,8 +754,8 @@ mod tests { let response = KeyClaimResponse::try_from_http_response(response).unwrap(); let alice = user_id!("@alice:example.org"); - let alice_account = Account::with_device_id(alice, "DEVICEID".into()); - let alice_device = ReadOnlyDevice::from_account(&alice_account).await; + let mut alice_account = Account::with_device_id(alice, "DEVICEID".into()); + let alice_device = ReadOnlyDevice::from_account(&alice_account); let manager = session_manager_test_helper().await; manager.store.save_devices(&[alice_device]).await.unwrap(); @@ -765,8 +772,8 @@ mod tests { // Since alice is timed out, we won't claim keys for her. assert!(manager.get_missing_sessions(iter::once(alice)).await.unwrap().is_none()); - alice_account.generate_one_time_keys_helper(1).await; - let one_time = alice_account.signed_one_time_keys().await; + alice_account.generate_one_time_keys_helper(1); + let one_time = alice_account.signed_one_time_keys(); assert!(!one_time.is_empty()); let mut one_time_keys = BTreeMap::new(); diff --git a/crates/matrix-sdk-crypto/src/store/caches.rs b/crates/matrix-sdk-crypto/src/store/caches.rs index 2c3491fd91a..5535eda693b 100644 --- a/crates/matrix-sdk-crypto/src/store/caches.rs +++ b/crates/matrix-sdk-crypto/src/store/caches.rs @@ -397,12 +397,12 @@ mod tests { use super::{DeviceStore, GroupSessionStore, SequenceNumber, SessionStore}; use crate::{ identities::device::testing::get_device, - olm::{tests::get_account_and_session, InboundGroupSession}, + olm::{tests::get_account_and_session_test_helper, InboundGroupSession}, }; #[async_test] async fn test_session_store() { - let (_, session) = get_account_and_session().await; + let (_, session) = get_account_and_session_test_helper(); let store = SessionStore::new(); @@ -419,7 +419,7 @@ mod tests { #[async_test] async fn test_session_store_bulk_storing() { - let (_, session) = get_account_and_session().await; + let (_, session) = get_account_and_session_test_helper(); let store = SessionStore::new(); store.set_for_sender(&session.sender_key.to_base64(), vec![session.clone()]); @@ -434,7 +434,7 @@ mod tests { #[async_test] async fn test_group_session_store() { - let (account, _) = get_account_and_session().await; + let (account, _) = get_account_and_session_test_helper(); let room_id = room_id!("!test:localhost"); let curve_key = "Nn0L2hkcCMFKqynTjyGsJbth7QrVmX3lbrksMkrGOAw"; diff --git a/crates/matrix-sdk-crypto/src/store/integration_tests.rs b/crates/matrix-sdk-crypto/src/store/integration_tests.rs index 50b4087df59..37750d9ace0 100644 --- a/crates/matrix-sdk-crypto/src/store/integration_tests.rs +++ b/crates/matrix-sdk-crypto/src/store/integration_tests.rs @@ -66,7 +66,7 @@ macro_rules! cryptostore_integration_tests { let store = get_store(name, None).await; let account = get_account(); - store.save_pending_changes(PendingChanges { account: Some(account.clone()), }).await.expect("Can't save account"); + store.save_pending_changes(PendingChanges { account: Some(account.deep_clone()), }).await.expect("Can't save account"); (account, store) } @@ -77,10 +77,10 @@ macro_rules! cryptostore_integration_tests { async fn get_account_and_session() -> (Account, Session) { let alice = Account::with_device_id(alice_id(), alice_device_id()); - let bob = Account::with_device_id(bob_id(), bob_device_id()); + let mut bob = Account::with_device_id(bob_id(), bob_device_id()); - bob.generate_one_time_keys_helper(1).await; - let one_time_key = *bob.one_time_keys().await.values().next().unwrap(); + bob.generate_one_time_keys_helper(1); + let one_time_key = *bob.one_time_keys().values().next().unwrap(); let sender_key = bob.identity_keys().curve25519; let session = alice .create_outbound_session_helper( @@ -88,8 +88,7 @@ macro_rules! cryptostore_integration_tests { sender_key, one_time_key, false, - ) - .await; + ); (alice, session) } @@ -124,7 +123,7 @@ macro_rules! cryptostore_integration_tests { let store = get_store("load_account", None).await; let account = get_account(); - store.save_pending_changes(PendingChanges { account: Some(account.clone()), }).await.expect("Can't save account"); + store.save_pending_changes(PendingChanges { account: Some(account.deep_clone()), }).await.expect("Can't save account"); let loaded_account = store.load_account().await.expect("Can't load account"); let loaded_account = loaded_account.unwrap(); @@ -138,7 +137,7 @@ macro_rules! cryptostore_integration_tests { get_store("load_account_with_passphrase", Some("secret_passphrase")).await; let account = get_account(); - store.save_pending_changes(PendingChanges { account: Some(account.clone()), }).await.expect("Can't save account"); + store.save_pending_changes(PendingChanges { account: Some(account.deep_clone()), }).await.expect("Can't save account"); let loaded_account = store.load_account().await.expect("Can't load account"); let loaded_account = loaded_account.unwrap(); @@ -149,14 +148,14 @@ macro_rules! cryptostore_integration_tests { #[async_test] async fn save_and_share_account() { let store = get_store("save_and_share_account", None).await; - let account = get_account(); + let mut account = get_account(); - store.save_pending_changes(PendingChanges { account: Some(account.clone()), }).await.expect("Can't save account"); + store.save_pending_changes(PendingChanges { account: Some(account.deep_clone()), }).await.expect("Can't save account"); account.mark_as_shared(); account.update_uploaded_key_count(50); - store.save_pending_changes(PendingChanges { account: Some(account.clone()), }).await.expect("Can't save account"); + store.save_pending_changes(PendingChanges { account: Some(account.deep_clone()), }).await.expect("Can't save account"); let loaded_account = store.load_account().await.expect("Can't load account"); let loaded_account = loaded_account.unwrap(); @@ -169,7 +168,7 @@ macro_rules! cryptostore_integration_tests { async fn load_sessions() { let store = get_store("load_sessions", None).await; let (account, session) = get_account_and_session().await; - store.save_pending_changes(PendingChanges { account: Some(account.clone()), }).await.expect("Can't save account"); + store.save_pending_changes(PendingChanges { account: Some(account.deep_clone()), }).await.expect("Can't save account"); let changes = Changes { sessions: vec![session.clone()], ..Default::default() }; @@ -193,7 +192,7 @@ macro_rules! cryptostore_integration_tests { let sender_key = session.sender_key.to_base64(); let session_id = session.session_id().to_owned(); - store.save_pending_changes(PendingChanges { account: Some(account.clone()), }).await.expect("Can't save account"); + store.save_pending_changes(PendingChanges { account: Some(account.deep_clone()), }).await.expect("Can't save account"); let changes = Changes { sessions: vec![session.clone()], ..Default::default() }; store.save_changes(changes).await.unwrap(); @@ -399,21 +398,19 @@ macro_rules! cryptostore_integration_tests { } #[async_test] - async fn device_saving() { + async fn test_device_saving() { let dir = "device_saving"; let (_account, store) = get_loaded_store(dir.clone()).await; let alice_device_1 = ReadOnlyDevice::from_account(&Account::with_device_id( "@alice:localhost".try_into().unwrap(), "FIRSTDEVICE".into(), - )) - .await; + )); let alice_device_2 = ReadOnlyDevice::from_account(&Account::with_device_id( "@alice:localhost".try_into().unwrap(), "SECONDDEVICE".into(), - )) - .await; + )); let changes = Changes { devices: DeviceChanges { @@ -491,7 +488,7 @@ macro_rules! cryptostore_integration_tests { let account = Account::with_device_id(&user_id, device_id); - store.save_pending_changes(PendingChanges { account: Some(account.clone()), }) + store.save_pending_changes(PendingChanges { account: Some(account), }) .await .expect("Can't save account"); diff --git a/crates/matrix-sdk-crypto/src/store/memorystore.rs b/crates/matrix-sdk-crypto/src/store/memorystore.rs index 65ecf8eae62..163d93b6b5f 100644 --- a/crates/matrix-sdk-crypto/src/store/memorystore.rs +++ b/crates/matrix-sdk-crypto/src/store/memorystore.rs @@ -52,6 +52,7 @@ fn encode_key_info(info: &SecretInfo) -> String { /// An in-memory only store that will forget all the E2EE key once it's dropped. #[derive(Debug)] pub struct MemoryStore { + account: StdRwLock>, sessions: SessionStore, inbound_group_sessions: GroupSessionStore, olm_hashes: StdRwLock>>, @@ -70,6 +71,7 @@ pub struct MemoryStore { impl Default for MemoryStore { fn default() -> Self { MemoryStore { + account: Default::default(), sessions: SessionStore::new(), inbound_group_sessions: GroupSessionStore::new(), olm_hashes: Default::default(), @@ -126,7 +128,7 @@ impl CryptoStore for MemoryStore { type Error = Infallible; async fn load_account(&self) -> Result> { - Ok(None) + Ok(self.account.read().unwrap().as_ref().map(|acc| acc.deep_clone())) } async fn load_identity(&self) -> Result> { @@ -137,8 +139,11 @@ impl CryptoStore for MemoryStore { Ok(self.next_batch_token.read().await.clone()) } - async fn save_pending_changes(&self, _changes: PendingChanges) -> Result<()> { - // TODO(bnjbvr) why didn't save_changes save the account? + async fn save_pending_changes(&self, changes: PendingChanges) -> Result<()> { + if let Some(account) = changes.account { + *self.account.write().unwrap() = Some(account); + } + Ok(()) } @@ -440,13 +445,13 @@ mod tests { use crate::{ identities::device::testing::get_device, - olm::{tests::get_account_and_session, InboundGroupSession, OlmMessageHash}, + olm::{tests::get_account_and_session_test_helper, InboundGroupSession, OlmMessageHash}, store::{memorystore::MemoryStore, Changes, CryptoStore, PendingChanges}, }; #[async_test] async fn test_session_store() { - let (account, session) = get_account_and_session().await; + let (account, session) = get_account_and_session_test_helper(); let store = MemoryStore::new(); assert!(store.load_account().await.unwrap().is_none()); @@ -464,7 +469,7 @@ mod tests { #[async_test] async fn test_group_session_store() { - let (account, _) = get_account_and_session().await; + let (account, _) = get_account_and_session_test_helper(); let room_id = room_id!("!test:localhost"); let curve_key = "Nn0L2hkcCMFKqynTjyGsJbth7QrVmX3lbrksMkrGOAw"; diff --git a/crates/matrix-sdk-crypto/src/store/mod.rs b/crates/matrix-sdk-crypto/src/store/mod.rs index 7c7ae28b91f..8ebcd6004cb 100644 --- a/crates/matrix-sdk-crypto/src/store/mod.rs +++ b/crates/matrix-sdk-crypto/src/store/mod.rs @@ -55,7 +55,7 @@ use ruma::{ }; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use thiserror::Error; -use tokio::sync::{Mutex, RwLock}; +use tokio::sync::{Mutex, MutexGuard, OwnedRwLockReadGuard, OwnedRwLockWriteGuard, RwLock}; use tracing::{info, warn}; use vodozemac::{base64_encode, megolm::SessionOrdering, Curve25519PublicKey}; use zeroize::Zeroize; @@ -107,14 +107,180 @@ pub struct Store { #[derive(Debug)] pub(crate) struct StoreCache { + store: Arc, + tracked_users: StdRwLock>, tracked_user_loading_lock: RwLock, - pub account: Account, + account: Mutex>, +} + +impl StoreCache { + /// Returns a reference to the `Account.` + /// + /// Either load the account from the cache, or the store, if missing from + /// the cache. + /// + /// Note there should always be an account stored at least in the store, so + /// this doesn't return an `Option`. + pub async fn account(&self) -> Result + '_> { + let mut guard = self.account.lock().await; + if guard.is_some() { + Ok(MutexGuard::map(guard, |acc| acc.as_mut().unwrap())) + } else { + match self.store.load_account().await? { + Some(account) => { + *guard = Some(account); + Ok(MutexGuard::map(guard, |acc| acc.as_mut().unwrap())) + } + None => Err(CryptoStoreError::AccountUnset), + } + } + } + + /// Load the list of users for whom we are tracking their device lists and + /// fill out our caches. + /// + /// This method ensures that we're only going to load the users from the + /// actual [`CryptoStore`] once, it will also make sure that any + /// concurrent calls to this method get deduplicated. + async fn ensure_sync_tracked_users(&self, store: &Store) -> Result<()> { + // Check if the users are loaded, and in that case do nothing. + let loaded = self.tracked_user_loading_lock.read().await; + if *loaded { + return Ok(()); + } + + // Otherwise, we may load the users. + drop(loaded); + let mut loaded = self.tracked_user_loading_lock.write().await; + + // Check again if the users have been loaded, in case another call to this + // method loaded the tracked users between the time we tried to + // acquire the lock and the time we actually acquired the lock. + if *loaded { + return Ok(()); + } + + let tracked_users = store.inner.store.load_tracked_users().await?; + + let mut query_users_lock = store.inner.users_for_key_query.lock().await; + let mut tracked_users_cache = self.tracked_users.write().unwrap(); + for user in tracked_users { + tracked_users_cache.insert(user.user_id.to_owned()); + + if user.dirty { + query_users_lock.insert_user(&user.user_id); + } + } + + *loaded = true; + + Ok(()) + } + + /// Process notifications that users have changed devices. + /// + /// This is used to handle the list of device-list updates that is received + /// from the `/sync` response. Any users *whose device lists we are + /// tracking* are flagged as needing a key query. Users whose devices we + /// are not tracking are ignored. + pub(crate) async fn mark_tracked_users_as_changed( + &self, + store: &Store, + users: impl Iterator, + ) -> Result<()> { + let mut store_updates: Vec<(&UserId, bool)> = Vec::new(); + let mut key_query_lock = store.inner.users_for_key_query.lock().await; + + { + let tracked_users = &self.tracked_users.read().unwrap(); + for user_id in users { + if tracked_users.contains(user_id) { + key_query_lock.insert_user(user_id); + store_updates.push((user_id, true)); + } + } + } + + store.inner.store.save_tracked_users(&store_updates).await + } + + /// Mark the given user as being tracked for device lists, and mark that it + /// has an outdated device list. + /// + /// This means that the user will be considered for a `/keys/query` request + /// next time [`Store::users_for_key_query()`] is called. + pub(crate) async fn mark_user_as_changed(&self, store: &Store, user: &UserId) -> Result<()> { + store.inner.users_for_key_query.lock().await.insert_user(user); + self.tracked_users.write().unwrap().insert(user.to_owned()); + + store.inner.store.save_tracked_users(&[(user, true)]).await + } + + /// Add entries to the list of users being tracked for device changes + /// + /// Any users not already on the list are flagged as awaiting a key query. + /// Users that were already in the list are unaffected. + pub(crate) async fn update_tracked_users( + &self, + store: &Store, + users: impl Iterator, + ) -> Result<()> { + let mut store_updates = Vec::new(); + let mut key_query_lock = store.inner.users_for_key_query.lock().await; + + { + let mut tracked_users = self.tracked_users.write().unwrap(); + for user_id in users { + if tracked_users.insert(user_id.to_owned()) { + key_query_lock.insert_user(user_id); + store_updates.push((user_id, true)) + } + } + } + + store.inner.store.save_tracked_users(&store_updates).await + } + + /// See the docs for [`crate::OlmMachine::tracked_users()`]. + pub(crate) fn tracked_users(&self) -> HashSet { + self.tracked_users.read().unwrap().iter().cloned().collect() + } + + /// Flag that the given users devices are now up-to-date. + /// + /// This is called after processing the response to a /keys/query request. + /// Any users whose device lists we are tracking are removed from the + /// list of those pending a /keys/query. + pub(crate) async fn mark_tracked_users_as_up_to_date( + &self, + store: &Store, + users: impl Iterator, + sequence_number: SequenceNumber, + ) -> Result<()> { + let mut store_updates: Vec<(&UserId, bool)> = Vec::new(); + let mut key_query_lock = store.inner.users_for_key_query.lock().await; + + { + let tracked_users = self.tracked_users.read().unwrap(); + for user_id in users { + if tracked_users.contains(user_id) { + let clean = key_query_lock.maybe_remove_user(user_id, sequence_number); + store_updates.push((user_id, !clean)); + } + } + } + + store.inner.store.save_tracked_users(&store_updates).await?; + // wake up any tasks that may have been waiting for updates + store.inner.users_for_key_query_condvar.notify_all(); + + Ok(()) + } } pub(crate) struct StoreCacheGuard { - cache: Arc, - //cache: OwnedRwLockReadGuard, + cache: OwnedRwLockReadGuard, // TODO: (bnjbvr, #2624) add cross-process lock guard here. } @@ -132,15 +298,23 @@ pub struct StoreTransaction { store: Store, changes: PendingChanges, // TODO hold onto the cross-process crypto store lock + cache. - cache: Arc, - //cache: OwnedRwLockWriteGuard, + cache: OwnedRwLockWriteGuard, } impl StoreTransaction { /// Starts a new `StoreTransaction`. pub async fn new(store: Store) -> Result { let cache = store.inner.cache.clone(); - Ok(Self { store, changes: PendingChanges::default(), cache }) + + Ok(Self { + store, + changes: PendingChanges::default(), + cache: cache.clone().write_owned().await, + }) + } + + pub(crate) fn cache(&self) -> &StoreCache { + &self.cache } /// Returns a reference to the current `Store`. @@ -151,7 +325,9 @@ impl StoreTransaction { /// Gets a `Account` for update. pub async fn account(&mut self) -> Result<&mut Account> { if self.changes.account.is_none() { - self.changes.account = Some(self.cache.account.clone()); + // Make sure the cache loaded the account. + let _ = self.cache.account().await?; + self.changes.account = self.cache.account.lock().await.take(); } Ok(self.changes.account.as_mut().unwrap()) } @@ -159,11 +335,19 @@ impl StoreTransaction { /// Commits all dirty fields to the store, and maintains the cache so it /// reflects the current state of the database. pub async fn commit(self) -> Result<()> { + if self.changes.is_empty() { + return Ok(()); + } + // Save changes in the database. + let account = self.changes.account.as_ref().map(|acc| acc.deep_clone()); + self.store.save_pending_changes(self.changes).await?; // Make the cache coherent with the database. - // for changes.account: nothing to do, it's the same underlying shared account. + if let Some(account) = account { + *self.cache.account.lock().await = Some(account); + } Ok(()) } @@ -177,8 +361,7 @@ struct StoreInner { /// In-memory cache for the current crypto store. /// /// ⚠ Must remain private. - // TODO: (bnjbvr, #2624) add RwLock here - cache: Arc, + cache: Arc>, verification_machine: VerificationMachine, @@ -592,24 +775,25 @@ impl From<&InboundGroupSession> for RoomKeyInfo { impl Store { /// Create a new Store. pub(crate) fn new( - account: Account, + account: StaticAccountData, identity: Arc>, store: Arc, verification_machine: VerificationMachine, ) -> Self { Self { inner: Arc::new(StoreInner { - static_account: account.static_data().clone(), + static_account: account, identity, - store, + store: store.clone(), verification_machine, users_for_key_query: AsyncStdMutex::new(UsersForKeyQuery::new()), users_for_key_query_condvar: Condvar::new(), - cache: Arc::new(StoreCache { + cache: Arc::new(RwLock::new(StoreCache { + store, tracked_users: Default::default(), tracked_user_loading_lock: Default::default(), - account, - }), + account: Default::default(), + })), }), } } @@ -635,10 +819,10 @@ impl Store { // - if acquired, look if another process touched the underlying storage, // - if yes, reload everything; if no, return current cache - let cache = StoreCacheGuard { cache: self.inner.cache.clone() }; + let cache = StoreCacheGuard { cache: self.inner.cache.clone().read_owned().await }; // Make sure tracked users are always up to date. - self.ensure_sync_tracked_users(&cache).await?; + cache.ensure_sync_tracked_users(self).await?; Ok(cache) } @@ -981,143 +1165,6 @@ impl Store { Ok(()) } - /// Mark the given user as being tracked for device lists, and mark that it - /// has an outdated device list. - /// - /// This means that the user will be considered for a `/keys/query` request - /// next time [`Store::users_for_key_query()`] is called. - pub(crate) async fn mark_user_as_changed(&self, user: &UserId) -> Result<()> { - self.inner.users_for_key_query.lock().await.insert_user(user); - self.cache().await?.tracked_users.write().unwrap().insert(user.to_owned()); - - self.inner.store.save_tracked_users(&[(user, true)]).await - } - - /// Add entries to the list of users being tracked for device changes - /// - /// Any users not already on the list are flagged as awaiting a key query. - /// Users that were already in the list are unaffected. - pub(crate) async fn update_tracked_users( - &self, - users: impl Iterator, - ) -> Result<()> { - let cache = self.cache().await?; - - let mut store_updates = Vec::new(); - let mut key_query_lock = self.inner.users_for_key_query.lock().await; - - { - let mut tracked_users = cache.tracked_users.write().unwrap(); - for user_id in users { - if tracked_users.insert(user_id.to_owned()) { - key_query_lock.insert_user(user_id); - store_updates.push((user_id, true)) - } - } - } - - self.inner.store.save_tracked_users(&store_updates).await - } - - /// Process notifications that users have changed devices. - /// - /// This is used to handle the list of device-list updates that is received - /// from the `/sync` response. Any users *whose device lists we are - /// tracking* are flagged as needing a key query. Users whose devices we - /// are not tracking are ignored. - pub(crate) async fn mark_tracked_users_as_changed( - &self, - users: impl Iterator, - ) -> Result<()> { - let cache = self.cache().await?; - - let mut store_updates: Vec<(&UserId, bool)> = Vec::new(); - let mut key_query_lock = self.inner.users_for_key_query.lock().await; - - { - let tracked_users = &cache.tracked_users.read().unwrap(); - for user_id in users { - if tracked_users.contains(user_id) { - key_query_lock.insert_user(user_id); - store_updates.push((user_id, true)); - } - } - } - - self.inner.store.save_tracked_users(&store_updates).await - } - - /// Flag that the given users devices are now up-to-date. - /// - /// This is called after processing the response to a /keys/query request. - /// Any users whose device lists we are tracking are removed from the - /// list of those pending a /keys/query. - pub(crate) async fn mark_tracked_users_as_up_to_date( - &self, - users: impl Iterator, - sequence_number: SequenceNumber, - ) -> Result<()> { - let mut store_updates: Vec<(&UserId, bool)> = Vec::new(); - let mut key_query_lock = self.inner.users_for_key_query.lock().await; - - { - let cache = self.cache().await?; - let tracked_users = cache.tracked_users.read().unwrap(); - for user_id in users { - if tracked_users.contains(user_id) { - let clean = key_query_lock.maybe_remove_user(user_id, sequence_number); - store_updates.push((user_id, !clean)); - } - } - } - self.inner.store.save_tracked_users(&store_updates).await?; - // wake up any tasks that may have been waiting for updates - self.inner.users_for_key_query_condvar.notify_all(); - - Ok(()) - } - - /// Load the list of users for whom we are tracking their device lists and - /// fill out our caches. - /// - /// This method ensures that we're only going to load the users from the - /// actual [`CryptoStore`] once, it will also make sure that any - /// concurrent calls to this method get deduplicated. - async fn ensure_sync_tracked_users(&self, cache: &StoreCacheGuard) -> Result<()> { - // Check if the users are loaded, and in that case do nothing. - let loaded = cache.tracked_user_loading_lock.read().await; - if *loaded { - return Ok(()); - } - - // Otherwise, we may load the users. - drop(loaded); - let mut loaded = cache.tracked_user_loading_lock.write().await; - - // Check again if the users have been loaded, in case another call to this - // method loaded the tracked users between the time we tried to - // acquire the lock and the time we actually acquired the lock. - if *loaded { - return Ok(()); - } - - let tracked_users = self.inner.store.load_tracked_users().await?; - - let mut query_users_lock = self.inner.users_for_key_query.lock().await; - let mut tracked_users_cache = cache.tracked_users.write().unwrap(); - for user in tracked_users { - tracked_users_cache.insert(user.user_id.to_owned()); - - if user.dirty { - query_users_lock.insert_user(&user.user_id); - } - } - - *loaded = true; - - Ok(()) - } - /// Get the set of users that has the outdate/dirty flag set for their list /// of devices. /// @@ -1132,9 +1179,6 @@ impl Store { pub(crate) async fn users_for_key_query( &self, ) -> Result<(HashSet, SequenceNumber)> { - // Make sure the tracked users set is up to date. - let _cache = self.cache().await?; - Ok(self.inner.users_for_key_query.lock().await.users_for_key_query()) } @@ -1175,11 +1219,6 @@ impl Store { } } - /// See the docs for [`crate::OlmMachine::tracked_users()`]. - pub(crate) async fn tracked_users(&self) -> Result> { - Ok(self.cache().await?.tracked_users.read().unwrap().iter().cloned().collect()) - } - /// Check whether there is a global flag to only encrypt messages for /// trusted devices or for everyone. pub async fn get_only_allow_trusted_devices(&self) -> Result { diff --git a/crates/matrix-sdk-crypto/src/verification/mod.rs b/crates/matrix-sdk-crypto/src/verification/mod.rs index 3e4a80296d3..102f85f2345 100644 --- a/crates/matrix-sdk-crypto/src/verification/mod.rs +++ b/crates/matrix-sdk-crypto/src/verification/mod.rs @@ -842,8 +842,8 @@ pub(crate) mod tests { let bob_readonly_identity = ReadOnlyOwnUserIdentity::from_private(&*bob_private_identity.lock().await).await; - let alice_device = ReadOnlyDevice::from_account(&alice).await; - let bob_device = ReadOnlyDevice::from_account(&bob).await; + let alice_device = ReadOnlyDevice::from_account(&alice); + let bob_device = ReadOnlyDevice::from_account(&bob); let alice_changes = Changes { identities: IdentityChanges { diff --git a/crates/matrix-sdk-crypto/src/verification/qrcode.rs b/crates/matrix-sdk-crypto/src/verification/qrcode.rs index 54719f79d27..7c1d5270449 100644 --- a/crates/matrix-sdk-crypto/src/verification/qrcode.rs +++ b/crates/matrix-sdk-crypto/src/verification/qrcode.rs @@ -929,7 +929,7 @@ mod tests { let flow_id = FlowId::ToDevice("test_transaction".into()); let device_key = account.static_data.identity_keys.ed25519; - let alice_device = ReadOnlyDevice::from_account(&account).await; + let alice_device = ReadOnlyDevice::from_account(&account); let identities = store.get_identities(alice_device).await.unwrap(); @@ -998,8 +998,8 @@ mod tests { let master_key = private_identity.master_public_key().await.unwrap(); let master_key = master_key.get_first_key().unwrap().to_owned(); - let alice_device = ReadOnlyDevice::from_account(&alice_account).await; - let bob_device = ReadOnlyDevice::from_account(&bob_account).await; + let alice_device = ReadOnlyDevice::from_account(&alice_account); + let bob_device = ReadOnlyDevice::from_account(&bob_account); let mut changes = Changes::default(); changes.identities.new.push(identity.clone().into()); diff --git a/crates/matrix-sdk-crypto/src/verification/requests.rs b/crates/matrix-sdk-crypto/src/verification/requests.rs index 4214fccda58..13718e00cbe 100644 --- a/crates/matrix-sdk-crypto/src/verification/requests.rs +++ b/crates/matrix-sdk-crypto/src/verification/requests.rs @@ -1705,7 +1705,7 @@ mod tests { // test what happens when we cancel() a request that we have just received over // to-device messages. let (_alice, alice_store, bob, bob_store) = setup_stores().await; - let bob_device = ReadOnlyDevice::from_account(&bob).await; + let bob_device = ReadOnlyDevice::from_account(&bob); // Set up the pair of verification requests let bob_request = build_test_request(&bob_store, alice_id(), None); @@ -1742,7 +1742,7 @@ mod tests { let room_id = room_id!("!test:localhost"); let (_alice, alice_store, bob, bob_store) = setup_stores().await; - let bob_device = ReadOnlyDevice::from_account(&bob).await; + let bob_device = ReadOnlyDevice::from_account(&bob); let content = VerificationRequest::request( &bob_store.account.user_id, @@ -1797,7 +1797,7 @@ mod tests { #[async_test] async fn test_requesting_until_sas_to_device() { let (_alice, alice_store, bob, bob_store) = setup_stores().await; - let bob_device = ReadOnlyDevice::from_account(&bob).await; + let bob_device = ReadOnlyDevice::from_account(&bob); // Set up the pair of verification requests let bob_request = build_test_request(&bob_store, alice_id(), None); diff --git a/crates/matrix-sdk-crypto/src/verification/sas/mod.rs b/crates/matrix-sdk-crypto/src/verification/sas/mod.rs index 9df4cdf91c7..c85f19a9a88 100644 --- a/crates/matrix-sdk-crypto/src/verification/sas/mod.rs +++ b/crates/matrix-sdk-crypto/src/verification/sas/mod.rs @@ -888,13 +888,13 @@ mod tests { device_id!("BOBDEVCIE") } - async fn machine_pair() -> (VerificationStore, ReadOnlyDevice, VerificationStore, ReadOnlyDevice) - { + fn machine_pair_test_helper( + ) -> (VerificationStore, ReadOnlyDevice, VerificationStore, ReadOnlyDevice) { let alice = Account::with_device_id(alice_id(), alice_device_id()); - let alice_device = ReadOnlyDevice::from_account(&alice).await; + let alice_device = ReadOnlyDevice::from_account(&alice); let bob = Account::with_device_id(bob_id(), bob_device_id()); - let bob_device = ReadOnlyDevice::from_account(&bob).await; + let bob_device = ReadOnlyDevice::from_account(&bob); let alice_store = VerificationStore { account: alice.static_data.clone(), @@ -916,7 +916,7 @@ mod tests { #[async_test] async fn sas_wrapper_full() { - let (alice_store, alice_device, bob_store, bob_device) = machine_pair().await; + let (alice_store, alice_device, bob_store, bob_device) = machine_pair_test_helper(); let identities = alice_store.get_identities(bob_device).await.unwrap(); @@ -990,7 +990,7 @@ mod tests { #[async_test] async fn sas_with_restricted_methods() { - let (alice_store, alice_device, bob_store, bob_device) = machine_pair().await; + let (alice_store, alice_device, bob_store, bob_device) = machine_pair_test_helper(); let identities = alice_store.get_identities(bob_device).await.unwrap(); let short_auth_strings = vec![ShortAuthenticationString::Decimal]; diff --git a/crates/matrix-sdk-crypto/src/verification/sas/sas_state.rs b/crates/matrix-sdk-crypto/src/verification/sas/sas_state.rs index 60845249023..fc1f246c5f8 100644 --- a/crates/matrix-sdk-crypto/src/verification/sas/sas_state.rs +++ b/crates/matrix-sdk-crypto/src/verification/sas/sas_state.rs @@ -1553,10 +1553,10 @@ mod tests { mac_method: Option, ) -> (SasState, SasState) { let alice = Account::with_device_id(alice_id(), alice_device_id()); - let alice_device = ReadOnlyDevice::from_account(&alice).await; + let alice_device = ReadOnlyDevice::from_account(&alice); let bob = Account::with_device_id(bob_id(), bob_device_id()); - let bob_device = ReadOnlyDevice::from_account(&bob).await; + let bob_device = ReadOnlyDevice::from_account(&bob); let flow_id = TransactionId::new().into(); let alice_sas = SasState::::new( @@ -1805,7 +1805,7 @@ mod tests { } #[async_test] - async fn sas_unknown_method() { + async fn test_sas_unknown_method() { let (alice, bob) = get_sas_pair(None).await; let content = json!({ @@ -1824,12 +1824,12 @@ mod tests { } #[async_test] - async fn sas_from_start_unknown_method() { + async fn test_sas_from_start_unknown_method() { let alice = Account::with_device_id(alice_id(), alice_device_id()); - let alice_device = ReadOnlyDevice::from_account(&alice).await; + let alice_device = ReadOnlyDevice::from_account(&alice); let bob = Account::with_device_id(bob_id(), bob_device_id()); - let bob_device = ReadOnlyDevice::from_account(&bob).await; + let bob_device = ReadOnlyDevice::from_account(&bob); let flow_id = TransactionId::new().into(); let alice_sas = SasState::::new( diff --git a/crates/matrix-sdk-indexeddb/src/crypto_store.rs b/crates/matrix-sdk-indexeddb/src/crypto_store.rs index 84d8f481188..53d6984a97a 100644 --- a/crates/matrix-sdk-indexeddb/src/crypto_store.rs +++ b/crates/matrix-sdk-indexeddb/src/crypto_store.rs @@ -532,7 +532,7 @@ impl_crypto_store! { let account_pickle = if let Some(account) = changes.account { *self.static_account.write().unwrap() = Some(account.static_data().clone()); - Some(account.pickle().await) + Some(account.pickle()) } else { None }; diff --git a/crates/matrix-sdk-sqlite/src/crypto_store.rs b/crates/matrix-sdk-sqlite/src/crypto_store.rs index 4629d9cf1fa..cb815873a7a 100644 --- a/crates/matrix-sdk-sqlite/src/crypto_store.rs +++ b/crates/matrix-sdk-sqlite/src/crypto_store.rs @@ -693,7 +693,7 @@ impl CryptoStore for SqliteCryptoStore { let pickled_account = if let Some(account) = changes.account { *self.static_account.write().unwrap() = Some(account.static_data().clone()); - Some(account.pickle().await) + Some(account.pickle()) } else { None };