Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Keep fallback key marked as used if it's re-uploaded (#11382)
Browse files Browse the repository at this point in the history
  • Loading branch information
uhoreg authored Nov 19, 2021
1 parent e2e9bea commit eca7cff
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 12 deletions.
1 change: 1 addition & 0 deletions changelog.d/11382.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Keep fallback key marked as used if it's re-uploaded.
51 changes: 40 additions & 11 deletions synapse/storage/databases/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,29 +408,58 @@ async def set_e2e_fallback_keys(
fallback_keys: the keys to set. This is a map from key ID (which is
of the form "algorithm:id") to key data.
"""
await self.db_pool.runInteraction(
"set_e2e_fallback_keys_txn",
self._set_e2e_fallback_keys_txn,
user_id,
device_id,
fallback_keys,
)

await self.invalidate_cache_and_stream(
"get_e2e_unused_fallback_key_types", (user_id, device_id)
)

def _set_e2e_fallback_keys_txn(
self, txn: Connection, user_id: str, device_id: str, fallback_keys: JsonDict
) -> None:
# fallback_keys will usually only have one item in it, so using a for
# loop (as opposed to calling simple_upsert_many_txn) won't be too bad
# FIXME: make sure that only one key per algorithm is uploaded
for key_id, fallback_key in fallback_keys.items():
algorithm, key_id = key_id.split(":", 1)
await self.db_pool.simple_upsert(
"e2e_fallback_keys_json",
old_key_json = self.db_pool.simple_select_one_onecol_txn(
txn,
table="e2e_fallback_keys_json",
keyvalues={
"user_id": user_id,
"device_id": device_id,
"algorithm": algorithm,
},
values={
"key_id": key_id,
"key_json": json_encoder.encode(fallback_key),
"used": False,
},
desc="set_e2e_fallback_key",
retcol="key_json",
allow_none=True,
)

await self.invalidate_cache_and_stream(
"get_e2e_unused_fallback_key_types", (user_id, device_id)
)
new_key_json = encode_canonical_json(fallback_key).decode("utf-8")

# If the uploaded key is the same as the current fallback key,
# don't do anything. This prevents marking the key as unused if it
# was already used.
if old_key_json != new_key_json:
self.db_pool.simple_upsert_txn(
txn,
table="e2e_fallback_keys_json",
keyvalues={
"user_id": user_id,
"device_id": device_id,
"algorithm": algorithm,
},
values={
"key_id": key_id,
"key_json": json_encoder.encode(fallback_key),
"used": False,
},
)

@cached(max_entries=10000)
async def get_e2e_unused_fallback_key_types(
Expand Down
32 changes: 31 additions & 1 deletion tests/handlers/test_e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def test_fallback_key(self):
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
fallback_key = {"alg1:k1": "key1"}
fallback_key2 = {"alg1:k2": "key2"}
otk = {"alg1:k2": "key2"}

# we shouldn't have any unused fallback keys yet
Expand Down Expand Up @@ -213,6 +214,35 @@ def test_fallback_key(self):
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
)

# re-uploading the same fallback key should still result in no unused fallback
# keys
self.get_success(
self.handler.upload_keys_for_user(
local_user,
device_id,
{"org.matrix.msc2732.fallback_keys": fallback_key},
)
)

res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
self.assertEqual(res, [])

# uploading a new fallback key should result in an unused fallback key
self.get_success(
self.handler.upload_keys_for_user(
local_user,
device_id,
{"org.matrix.msc2732.fallback_keys": fallback_key2},
)
)

res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
self.assertEqual(res, ["alg1"])

# if the user uploads a one-time key, the next claim should fetch the
# one-time key, and then go back to the fallback
self.get_success(
Expand All @@ -238,7 +268,7 @@ def test_fallback_key(self):
)
self.assertEqual(
res,
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key2}}},
)

def test_replace_master_key(self):
Expand Down

0 comments on commit eca7cff

Please sign in to comment.