Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Convert a synapse.events to async/await. (#7949)
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep authored Jul 27, 2020
1 parent 5f65e62 commit 8553f46
Show file tree
Hide file tree
Showing 13 changed files with 86 additions and 82 deletions.
2 changes: 1 addition & 1 deletion changelog.d/7948.misc
Original file line number Diff line number Diff line change
@@ -1 +1 @@
Convert push to async/await.
Convert various parts of the codebase to async/await.
1 change: 1 addition & 0 deletions changelog.d/7949.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.
2 changes: 1 addition & 1 deletion changelog.d/7951.misc
Original file line number Diff line number Diff line change
@@ -1 +1 @@
Convert groups and visibility code to async / await.
Convert various parts of the codebase to async/await.
2 changes: 1 addition & 1 deletion synapse/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def __init__(self, hs):

@defer.inlineCallbacks
def check_from_context(self, room_version: str, event, context, do_sig_check=True):
prev_state_ids = yield context.get_prev_state_ids()
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
auth_events_ids = yield self.compute_auth_events(
event, prev_state_ids, for_verification=True
)
Expand Down
19 changes: 8 additions & 11 deletions synapse/events/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
import attr
from nacl.signing import SigningKey

from twisted.internet import defer

from synapse.api.constants import MAX_DEPTH
from synapse.api.errors import UnsupportedRoomVersionError
from synapse.api.room_versions import (
Expand Down Expand Up @@ -95,31 +93,30 @@ def state_key(self):
def is_state(self):
return self._state_key is not None

@defer.inlineCallbacks
def build(self, prev_event_ids):
async def build(self, prev_event_ids):
"""Transform into a fully signed and hashed event
Args:
prev_event_ids (list[str]): The event IDs to use as the prev events
Returns:
Deferred[FrozenEvent]
FrozenEvent
"""

state_ids = yield defer.ensureDeferred(
self._state.get_current_state_ids(self.room_id, prev_event_ids)
state_ids = await self._state.get_current_state_ids(
self.room_id, prev_event_ids
)
auth_ids = yield self._auth.compute_auth_events(self, state_ids)
auth_ids = await self._auth.compute_auth_events(self, state_ids)

format_version = self.room_version.event_format
if format_version == EventFormatVersions.V1:
auth_events = yield self._store.add_event_hashes(auth_ids)
prev_events = yield self._store.add_event_hashes(prev_event_ids)
auth_events = await self._store.add_event_hashes(auth_ids)
prev_events = await self._store.add_event_hashes(prev_event_ids)
else:
auth_events = auth_ids
prev_events = prev_event_ids

old_depth = yield self._store.get_max_depth_of(prev_event_ids)
old_depth = await self._store.get_max_depth_of(prev_event_ids)
depth = old_depth + 1

# we cap depth of generated events, to ensure that they are not
Expand Down
46 changes: 22 additions & 24 deletions synapse/events/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union
from typing import TYPE_CHECKING, Optional, Union

import attr
from frozendict import frozendict

from twisted.internet import defer

from synapse.appservice import ApplicationService
from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.types import StateMap

if TYPE_CHECKING:
from synapse.storage.data_stores.main import DataStore


@attr.s(slots=True)
class EventContext:
Expand Down Expand Up @@ -129,8 +131,7 @@ def with_state(
delta_ids=delta_ids,
)

@defer.inlineCallbacks
def serialize(self, event, store):
async def serialize(self, event: EventBase, store: "DataStore") -> dict:
"""Converts self to a type that can be serialized as JSON, and then
deserialized by `deserialize`
Expand All @@ -146,7 +147,7 @@ def serialize(self, event, store):
# the prev_state_ids, so if we're a state event we include the event
# id that we replaced in the state.
if event.is_state():
prev_state_ids = yield self.get_prev_state_ids()
prev_state_ids = await self.get_prev_state_ids()
prev_state_id = prev_state_ids.get((event.type, event.state_key))
else:
prev_state_id = None
Expand Down Expand Up @@ -214,8 +215,7 @@ def state_group(self) -> Optional[int]:

return self._state_group

@defer.inlineCallbacks
def get_current_state_ids(self):
async def get_current_state_ids(self) -> Optional[StateMap[str]]:
"""
Gets the room state map, including this event - ie, the state in ``state_group``
Expand All @@ -224,32 +224,31 @@ def get_current_state_ids(self):
``rejected`` is set.
Returns:
Deferred[dict[(str, str), str]|None]: Returns None if state_group
is None, which happens when the associated event is an outlier.
Returns None if state_group is None, which happens when the associated
event is an outlier.
Maps a (type, state_key) to the event ID of the state event matching
this tuple.
Maps a (type, state_key) to the event ID of the state event matching
this tuple.
"""
if self.rejected:
raise RuntimeError("Attempt to access state_ids of rejected event")

yield self._ensure_fetched()
await self._ensure_fetched()
return self._current_state_ids

@defer.inlineCallbacks
def get_prev_state_ids(self):
async def get_prev_state_ids(self):
"""
Gets the room state map, excluding this event.
For a non-state event, this will be the same as get_current_state_ids().
Returns:
Deferred[dict[(str, str), str]|None]: Returns None if state_group
dict[(str, str), str]|None: Returns None if state_group
is None, which happens when the associated event is an outlier.
Maps a (type, state_key) to the event ID of the state event matching
this tuple.
"""
yield self._ensure_fetched()
await self._ensure_fetched()
return self._prev_state_ids

def get_cached_current_state_ids(self):
Expand All @@ -269,8 +268,8 @@ def get_cached_current_state_ids(self):

return self._current_state_ids

def _ensure_fetched(self):
return defer.succeed(None)
async def _ensure_fetched(self):
return None


@attr.s(slots=True)
Expand Down Expand Up @@ -303,21 +302,20 @@ class _AsyncEventContextImpl(EventContext):
_event_state_key = attr.ib(default=None)
_fetching_state_deferred = attr.ib(default=None)

def _ensure_fetched(self):
async def _ensure_fetched(self):
if not self._fetching_state_deferred:
self._fetching_state_deferred = run_in_background(self._fill_out_state)

return make_deferred_yieldable(self._fetching_state_deferred)
return await make_deferred_yieldable(self._fetching_state_deferred)

@defer.inlineCallbacks
def _fill_out_state(self):
async def _fill_out_state(self):
"""Called to populate the _current_state_ids and _prev_state_ids
attributes by loading from the database.
"""
if self.state_group is None:
return

self._current_state_ids = yield self._storage.state.get_state_ids_for_group(
self._current_state_ids = await self._storage.state.get_state_ids_for_group(
self.state_group
)
if self._event_state_key is not None:
Expand Down
55 changes: 30 additions & 25 deletions synapse/events/third_party_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from twisted.internet import defer
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.types import Requester


class ThirdPartyEventRules(object):
Expand All @@ -39,76 +41,79 @@ def __init__(self, hs):
config=config, http_client=hs.get_simple_http_client()
)

@defer.inlineCallbacks
def check_event_allowed(self, event, context):
async def check_event_allowed(
self, event: EventBase, context: EventContext
) -> bool:
"""Check if a provided event should be allowed in the given context.
Args:
event (synapse.events.EventBase): The event to be checked.
context (synapse.events.snapshot.EventContext): The context of the event.
event: The event to be checked.
context: The context of the event.
Returns:
defer.Deferred[bool]: True if the event should be allowed, False if not.
True if the event should be allowed, False if not.
"""
if self.third_party_rules is None:
return True

prev_state_ids = yield context.get_prev_state_ids()
prev_state_ids = await context.get_prev_state_ids()

# Retrieve the state events from the database.
state_events = {}
for key, event_id in prev_state_ids.items():
state_events[key] = yield self.store.get_event(event_id, allow_none=True)
state_events[key] = await self.store.get_event(event_id, allow_none=True)

ret = yield self.third_party_rules.check_event_allowed(event, state_events)
ret = await self.third_party_rules.check_event_allowed(event, state_events)
return ret

@defer.inlineCallbacks
def on_create_room(self, requester, config, is_requester_admin):
async def on_create_room(
self, requester: Requester, config: dict, is_requester_admin: bool
) -> bool:
"""Intercept requests to create room to allow, deny or update the
request config.
Args:
requester (Requester)
config (dict): The creation config from the client.
is_requester_admin (bool): If the requester is an admin
requester
config: The creation config from the client.
is_requester_admin: If the requester is an admin
Returns:
defer.Deferred[bool]: Whether room creation is allowed or denied.
Whether room creation is allowed or denied.
"""

if self.third_party_rules is None:
return True

ret = yield self.third_party_rules.on_create_room(
ret = await self.third_party_rules.on_create_room(
requester, config, is_requester_admin
)
return ret

@defer.inlineCallbacks
def check_threepid_can_be_invited(self, medium, address, room_id):
async def check_threepid_can_be_invited(
self, medium: str, address: str, room_id: str
) -> bool:
"""Check if a provided 3PID can be invited in the given room.
Args:
medium (str): The 3PID's medium.
address (str): The 3PID's address.
room_id (str): The room we want to invite the threepid to.
medium: The 3PID's medium.
address: The 3PID's address.
room_id: The room we want to invite the threepid to.
Returns:
defer.Deferred[bool], True if the 3PID can be invited, False if not.
True if the 3PID can be invited, False if not.
"""

if self.third_party_rules is None:
return True

state_ids = yield self.store.get_filtered_current_state_ids(room_id)
room_state_events = yield self.store.get_events(state_ids.values())
state_ids = await self.store.get_filtered_current_state_ids(room_id)
room_state_events = await self.store.get_events(state_ids.values())

state_events = {}
for key, event_id in state_ids.items():
state_events[key] = room_state_events[event_id]

ret = yield self.third_party_rules.check_threepid_can_be_invited(
ret = await self.third_party_rules.check_threepid_can_be_invited(
medium, address, state_events
)
return ret
15 changes: 7 additions & 8 deletions synapse/events/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@

from frozendict import frozendict

from twisted.internet import defer

from synapse.api.constants import EventTypes, RelationTypes
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import RoomVersion
Expand Down Expand Up @@ -337,8 +335,9 @@ def __init__(self, hs):
hs.config.experimental_msc1849_support_enabled
)

@defer.inlineCallbacks
def serialize_event(self, event, time_now, bundle_aggregations=True, **kwargs):
async def serialize_event(
self, event, time_now, bundle_aggregations=True, **kwargs
):
"""Serializes a single event.
Args:
Expand All @@ -348,7 +347,7 @@ def serialize_event(self, event, time_now, bundle_aggregations=True, **kwargs):
**kwargs: Arguments to pass to `serialize_event`
Returns:
Deferred[dict]: The serialized event
dict: The serialized event
"""
# To handle the case of presence events and the like
if not isinstance(event, EventBase):
Expand All @@ -363,8 +362,8 @@ def serialize_event(self, event, time_now, bundle_aggregations=True, **kwargs):
if not event.internal_metadata.is_redacted() and (
self.experimental_msc1849_support_enabled and bundle_aggregations
):
annotations = yield self.store.get_aggregation_groups_for_event(event_id)
references = yield self.store.get_relations_for_event(
annotations = await self.store.get_aggregation_groups_for_event(event_id)
references = await self.store.get_relations_for_event(
event_id, RelationTypes.REFERENCE, direction="f"
)

Expand All @@ -378,7 +377,7 @@ def serialize_event(self, event, time_now, bundle_aggregations=True, **kwargs):

edit = None
if event.type == EventTypes.Message:
edit = yield self.store.get_applicable_edit(event_id)
edit = await self.store.get_applicable_edit(event_id)

if edit:
# If there is an edit replace the content, preserving existing
Expand Down
2 changes: 1 addition & 1 deletion synapse/handlers/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2470,7 +2470,7 @@ async def _update_context_for_auth_events(
}

current_state_ids = await context.get_current_state_ids()
current_state_ids = dict(current_state_ids)
current_state_ids = dict(current_state_ids) # type: ignore

current_state_ids.update(state_updates)

Expand Down
4 changes: 3 additions & 1 deletion synapse/replication/http/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,9 @@ def _serialize_payload(store, event_and_contexts, backfilled):
"""
event_payloads = []
for event, context in event_and_contexts:
serialized_context = yield context.serialize(event, store)
serialized_context = yield defer.ensureDeferred(
context.serialize(event, store)
)

event_payloads.append(
{
Expand Down
2 changes: 1 addition & 1 deletion synapse/replication/http/send_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def _serialize_payload(
extra_users (list(UserID)): Any extra users to notify about event
"""

serialized_context = yield context.serialize(event, store)
serialized_context = yield defer.ensureDeferred(context.serialize(event, store))

payload = {
"event": event.get_pdu_json(),
Expand Down
Loading

0 comments on commit 8553f46

Please sign in to comment.