Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add ability to wait for locks and add locks to purge history / room deletion #15791

Merged
merged 11 commits into from
Jul 31, 2023
189 changes: 189 additions & 0 deletions synapse/handlers/worker_lock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
# Copyright 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
from types import TracebackType
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Type
from weakref import WeakSet

import attr

from twisted.internet import defer
from twisted.internet.interfaces import IReactorTime

from synapse.logging.context import PreserveLoggingContext
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage.databases.main.lock import Lock, LockStore
from synapse.util.async_helpers import timeout_deferred

if TYPE_CHECKING:
from synapse.server import HomeServer


class WorkerLocksHandler:
    """Hands out `WaitingLock` context managers that wait until a lock can be
    taken, as an alternative to calling the storage functions directly (which
    don't support awaiting).
    """

    def __init__(self, hs: "HomeServer") -> None:
        self._reactor = hs.get_reactor()
        self._store = hs.get_datastores().main
        self._clock = hs.get_clock()
        self._replication_handler = hs.get_replication_command_handler()
        self._notifier = hs.get_notifier()

        # (lock name, lock key) -> the `WaitingLock`s currently active for that
        # lock. The sets hold weak references, so entries empty out on their
        # own once the waiting locks are garbage collected.
        self._locks: Dict[Tuple[str, str], WeakSet[WaitingLock]] = {}

        # Periodically drop keys whose weak sets have emptied out.
        # NOTE(review): 30_000 is presumably milliseconds -- confirm against
        # the clock's `looping_call`.
        self._clock.looping_call(self._cleanup_locks, 30_000)

        # Get told about released locks so that waiters can be woken early.
        self._notifier.add_lock_released_callback(self._on_lock_released)

    def _new_waiting_lock(
        self, lock_name: str, lock_key: str, write: Optional[bool]
    ) -> "WaitingLock":
        """Build a `WaitingLock` for the given name/key and start tracking it.

        Args:
            lock_name: Name of the lock.
            lock_key: Key of the lock.
            write: None for a standard lock, otherwise whether the read/write
                lock should be taken for writing.
        """
        waiting_lock = WaitingLock(
            reactor=self._reactor,
            store=self._store,
            handler=self,
            lock_name=lock_name,
            lock_key=lock_key,
            write=write,
        )

        key = (lock_name, lock_key)
        self._locks.setdefault(key, WeakSet()).add(waiting_lock)

        return waiting_lock

    def acquire_lock(self, lock_name: str, lock_key: str) -> "WaitingLock":
        """Acquire a standard lock, returns a context manager that will block
        until the lock is acquired.

        Usage:
            async with handler.acquire_lock(name, key):
                # Do work while holding the lock...
        """
        return self._new_waiting_lock(lock_name, lock_key, None)

    def acquire_read_write_lock(
        self,
        lock_name: str,
        lock_key: str,
        *,
        write: bool,
    ) -> "WaitingLock":
        """Acquire a read/write lock, returns a context manager that will block
        until the lock is acquired.

        Usage:
            async with handler.acquire_read_write_lock(name, key, write=True):
                # Do work while holding the lock...
        """
        return self._new_waiting_lock(lock_name, lock_key, write)

    def notify_lock_released(self, lock_name: str, lock_key: str) -> None:
        """Announce that a lock has been released.

        Pokes replication first and then the local notifier.
        """
        self._replication_handler.send_lock_released(lock_name, lock_key)
        self._notifier.notify_lock_released(lock_name, lock_key)

    def _on_lock_released(self, lock_name: str, lock_key: str) -> None:
        """Called when a lock has been released.

        Wakes up any locks that might be waiting on this.
        """
        waiters = self._locks.get((lock_name, lock_key))
        if waiters:

            def _fire_if_pending(deferred: defer.Deferred) -> None:
                # The waiter may already have been woken (or timed out) by the
                # time this runs; only fire deferreds that are still pending.
                if not deferred.called:
                    deferred.callback(None)

            for waiter in waiters:
                # Schedule the wake-up through the clock rather than firing the
                # deferred inline. The deferred is captured now, at scheduling
                # time.
                self._clock.call_later(0, _fire_if_pending, waiter.deferred)

    @wrap_as_background_process("_cleanup_locks")
    async def _cleanup_locks(self) -> None:
        """Periodically cleans out stale entries in the locks map"""
        still_active: Dict[Tuple[str, str], WeakSet[WaitingLock]] = {}
        for key, waiters in self._locks.items():
            # An empty WeakSet is falsy: nothing is waiting on this lock.
            if waiters:
                still_active[key] = waiters
        self._locks = still_active


@attr.s(auto_attribs=True, eq=False)
class WaitingLock:
    """Async context manager that keeps trying to take out a lock until it
    succeeds.

    On entry it repeatedly asks the store for the lock (a read/write lock when
    `write` is not None, a standard lock otherwise), sleeping between attempts.
    On exit it tells the handler that the lock has been released and exits the
    underlying lock.
    """

    reactor: IReactorTime
    store: LockStore
    handler: WorkerLocksHandler
    lock_name: str
    lock_key: str
    # None for a standard lock; otherwise whether to take the read/write lock
    # for writing.
    write: Optional[bool]
    # Fired by the handler to wake us up early when the lock is reported as
    # released. Replaced with a fresh deferred on every acquisition attempt.
    deferred: "defer.Deferred[None]" = attr.Factory(defer.Deferred)
    # The underlying lock once it has been acquired.
    _inner_lock: Optional[Lock] = None
    # Base delay before the next retry; doubled (up to a cap) after each use.
    _retry_interval: float = 0.1

    async def __aenter__(self) -> None:
        """Block until the lock has been acquired, then enter it."""
        while self._inner_lock is None:
            # Fresh deferred for this attempt, created *before* trying to
            # acquire so that a release notification arriving in the meantime
            # is not missed.
            self.deferred = defer.Deferred()

            if self.write is not None:
                lock = await self.store.try_acquire_read_write_lock(
                    self.lock_name, self.lock_key, write=self.write
                )
            else:
                lock = await self.store.try_acquire_lock(self.lock_name, self.lock_key)

            if lock:
                self._inner_lock = lock
                break

            # We didn't get the lock. This is effectively polling at slightly
            # randomised, increasing intervals, with a shortcut: if we are
            # notified that the lock has been released, `self.deferred` fires
            # and we retry straight away instead of waiting out the timeout.
            try:
                with PreserveLoggingContext():
                    await timeout_deferred(
                        deferred=self.deferred,
                        timeout=self._get_next_retry_interval(),
                        reactor=self.reactor,
                    )
            except Exception:
                # Deliberately best-effort: hitting the timeout is the normal
                # "nobody woke us" case, so just go round the loop and retry.
                pass

        return await self._inner_lock.__aenter__()

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> Optional[bool]:
        """Notify that the lock is being released, then exit the inner lock."""
        assert self._inner_lock

        # Let the handler wake up anything else waiting on this lock.
        self.handler.notify_lock_released(self.lock_name, self.lock_key)

        return await self._inner_lock.__aexit__(exc_type, exc, tb)

    def _get_next_retry_interval(self) -> float:
        """Return the delay to use for the next retry, and advance the backoff.

        The base interval doubles after every call but is capped at 5, so that
        a long-held lock is still re-polled regularly. The returned value has
        +/-10% jitter applied so that multiple waiters don't retry in lockstep.
        """
        interval = self._retry_interval
        # Cap (not floor) the backoff: `max` here would jump straight from 0.1
        # to 5 and then keep doubling without bound.
        self._retry_interval = min(5, interval * 2)
        return interval * random.uniform(0.9, 1.1)
12 changes: 12 additions & 0 deletions synapse/notifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,9 @@ def __init__(self, hs: "HomeServer"):

self._third_party_rules = hs.get_module_api_callbacks().third_party_event_rules

# List of callbacks to be notified when a lock is released
self._lock_released_callback: List[Callable[[str, str], None]] = []

self.clock = hs.get_clock()
self.appservice_handler = hs.get_application_service_handler()
self._pusher_pool = hs.get_pusherpool()
Expand Down Expand Up @@ -785,6 +788,15 @@ def notify_remote_server_up(self, server: str) -> None:
# that any in flight requests can be immediately retried.
self._federation_client.wake_destination(server)

def add_lock_released_callback(self, callback: Callable[[str, str], None]) -> None:
    """Register a function to be invoked whenever we are notified that a lock
    has been released.

    Args:
        callback: Called with the lock name and lock key of the released lock.
    """
    registered = self._lock_released_callback
    registered.append(callback)

def notify_lock_released(self, lock_name: str, lock_key: str) -> None:
    """Invoke every registered lock-released callback, in registration order,
    with the name and key of the lock that was released.
    """
    for callback in self._lock_released_callback:
        callback(lock_name, lock_key)


@attr.s(auto_attribs=True)
class ReplicationNotifier:
Expand Down
31 changes: 31 additions & 0 deletions synapse/replication/tcp/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,34 @@ class RemoteServerUpCommand(_SimpleCommand):
NAME = "REMOTE_SERVER_UP"


class LockReleasedCommand(Command):
    """Sent to inform other instances that a given lock has been dropped.

    The line payload is a JSON array of the lock name and the lock key.

    Format::

        LOCK_RELEASED ["<lock_name>", "<lock_key>"]
    """

    NAME = "LOCK_RELEASED"

    def __init__(self, lock_name: str, lock_key: str):
        self.lock_name = lock_name
        self.lock_key = lock_key

    @classmethod
    def from_line(cls: Type["LockReleasedCommand"], line: str) -> "LockReleasedCommand":
        """Parse the JSON payload of a LOCK_RELEASED line into a command."""
        decoded = json_decoder.decode(line)
        # The payload must be exactly a two-element sequence.
        name, key = decoded
        return cls(name, key)

    def to_line(self) -> str:
        """Serialise this command as a JSON `[lock_name, lock_key]` array."""
        payload = [self.lock_name, self.lock_key]
        return json_encoder.encode(payload)


_COMMANDS: Tuple[Type[Command], ...] = (
ServerCommand,
RdataCommand,
Expand All @@ -435,6 +463,7 @@ class RemoteServerUpCommand(_SimpleCommand):
UserIpCommand,
RemoteServerUpCommand,
ClearUserSyncsCommand,
LockReleasedCommand,
)

# Map of command name to command type.
Expand All @@ -448,6 +477,7 @@ class RemoteServerUpCommand(_SimpleCommand):
ErrorCommand.NAME,
PingCommand.NAME,
RemoteServerUpCommand.NAME,
LockReleasedCommand.NAME,
)

# The commands the client is allowed to send
Expand All @@ -461,6 +491,7 @@ class RemoteServerUpCommand(_SimpleCommand):
UserIpCommand.NAME,
ErrorCommand.NAME,
RemoteServerUpCommand.NAME,
LockReleasedCommand.NAME,
)


Expand Down
11 changes: 11 additions & 0 deletions synapse/replication/tcp/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
ClearUserSyncsCommand,
Command,
FederationAckCommand,
LockReleasedCommand,
PositionCommand,
RdataCommand,
RemoteServerUpCommand,
Expand Down Expand Up @@ -648,6 +649,12 @@ def on_REMOTE_SERVER_UP(

self._notifier.notify_remote_server_up(cmd.data)

def on_LOCK_RELEASED(
    self, conn: IReplicationConnection, cmd: LockReleasedCommand
) -> None:
    """Handle an incoming LOCK_RELEASED command by passing the released
    lock's name and key on to the notifier.
    """
    released_name, released_key = cmd.lock_name, cmd.lock_key
    self._notifier.notify_lock_released(released_name, released_key)

def new_connection(self, connection: IReplicationConnection) -> None:
"""Called when we have a new connection."""
self._connections.append(connection)
Expand Down Expand Up @@ -754,6 +761,10 @@ def stream_update(self, stream_name: str, token: Optional[int], data: Any) -> No
"""
self.send_command(RdataCommand(stream_name, self._instance_name, token, data))

def send_lock_released(self, lock_name: str, lock_key: str) -> None:
    """Send a LOCK_RELEASED command for a lock we have released, so that
    other instances are notified.
    """
    command = LockReleasedCommand(lock_name, lock_key)
    self.send_command(command)


UpdateToken = TypeVar("UpdateToken")
UpdateRow = TypeVar("UpdateRow")
Expand Down
5 changes: 5 additions & 0 deletions synapse/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@
from synapse.handlers.sync import SyncHandler
from synapse.handlers.typing import FollowerTypingHandler, TypingWriterHandler
from synapse.handlers.user_directory import UserDirectoryHandler
from synapse.handlers.worker_lock import WorkerLocksHandler
from synapse.http.client import (
InsecureInterceptableContextFactory,
ReplicationClient,
Expand Down Expand Up @@ -912,3 +913,7 @@ def get_request_ratelimiter(self) -> RequestRatelimiter:
def get_common_usage_metrics_manager(self) -> CommonUsageMetricsManager:
"""Usage metrics shared between phone home stats and the prometheus exporter."""
return CommonUsageMetricsManager(self)

@cache_in_self
def get_worker_locks_handler(self) -> WorkerLocksHandler:
    """Build the handler used to wait on taking out worker locks."""
    handler = WorkerLocksHandler(self)
    return handler