Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Refactor to ensure we call check_consistency #9470

Merged
merged 5 commits into from
Feb 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/9470.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix missing startup checks for the consistency of certain PostgreSQL sequences.
16 changes: 4 additions & 12 deletions synapse/storage/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection, Cursor
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import Collection

# python 3 does not have a maximum int value
Expand Down Expand Up @@ -381,7 +380,10 @@ class DatabasePool:
_TXN_ID = 0

def __init__(
self, hs, database_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
self,
hs,
database_config: DatabaseConnectionConfig,
engine: BaseDatabaseEngine,
):
self.hs = hs
self._clock = hs.get_clock()
Expand Down Expand Up @@ -420,16 +422,6 @@ def __init__(
self._check_safe_to_upsert,
)

# We define this sequence here so that it can be referenced from both
# the DataStore and PersistEventStore.
def get_chain_id_txn(txn):
txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
return txn.fetchone()[0]

self.event_chain_id_gen = build_sequence_generator(
engine, get_chain_id_txn, "event_auth_chain_id"
)

def is_running(self) -> bool:
"""Is the database pool currently running"""
return self._db_pool.running
Expand Down
13 changes: 6 additions & 7 deletions synapse/storage/databases/main/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from synapse.storage.databases.main.search import SearchEntry
from synapse.storage.types import Connection
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.storage.util.sequence import SequenceGenerator
from synapse.types import StateMap, get_domain_from_id
from synapse.util import json_encoder
from synapse.util.iterutils import batch_iter, sorted_topologically
Expand Down Expand Up @@ -114,12 +115,6 @@ def __init__(
) # type: MultiWriterIdGenerator
self._stream_id_gen = self.store._stream_id_gen # type: MultiWriterIdGenerator

# The consistency of this cannot be checked when the ID generator is
# created since the database might not yet be up-to-date.
self.db_pool.event_chain_id_gen.check_consistency(
db_conn, "event_auth_chains", "chain_id" # type: ignore
)

# This should only exist on instances that are configured to write
assert (
hs.get_instance_name() in hs.config.worker.writers.events
Expand Down Expand Up @@ -485,6 +480,7 @@ def _persist_event_auth_chain_txn(
self._add_chain_cover_index(
txn,
self.db_pool,
self.store.event_chain_id_gen,
event_to_room_id,
event_to_types,
event_to_auth_chain,
Expand All @@ -495,6 +491,7 @@ def _add_chain_cover_index(
cls,
txn,
db_pool: DatabasePool,
event_chain_id_gen: SequenceGenerator,
event_to_room_id: Dict[str, str],
event_to_types: Dict[str, Tuple[str, str]],
event_to_auth_chain: Dict[str, List[str]],
Expand Down Expand Up @@ -641,6 +638,7 @@ def _add_chain_cover_index(
new_chain_tuples = cls._allocate_chain_ids(
txn,
db_pool,
event_chain_id_gen,
event_to_room_id,
event_to_types,
event_to_auth_chain,
Expand Down Expand Up @@ -779,6 +777,7 @@ def _add_chain_cover_index(
def _allocate_chain_ids(
txn,
db_pool: DatabasePool,
event_chain_id_gen: SequenceGenerator,
event_to_room_id: Dict[str, str],
event_to_types: Dict[str, Tuple[str, str]],
event_to_auth_chain: Dict[str, List[str]],
Expand Down Expand Up @@ -891,7 +890,7 @@ def _allocate_chain_ids(
chain_to_max_seq_no[new_chain_tuple[0]] = new_chain_tuple[1]

# Generate new chain IDs for all unallocated chain IDs.
newly_allocated_chain_ids = db_pool.event_chain_id_gen.get_next_mult_txn(
newly_allocated_chain_ids = event_chain_id_gen.get_next_mult_txn(
txn, len(unallocated_chain_ids)
)

Expand Down
1 change: 1 addition & 0 deletions synapse/storage/databases/main/events_bg_updates.py
Original file line number Diff line number Diff line change
Expand Up @@ -917,6 +917,7 @@ def _calculate_chain_cover_txn(
PersistEventsStore._add_chain_cover_index(
txn,
self.db_pool,
self.event_chain_id_gen,
event_to_room_id,
event_to_types,
event_to_auth_chain,
Expand Down
16 changes: 16 additions & 0 deletions synapse/storage/databases/main/events_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from synapse.storage.database import DatabasePool
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import Collection, JsonDict, get_domain_from_id
from synapse.util.caches.descriptors import cached
from synapse.util.caches.lrucache import LruCache
Expand Down Expand Up @@ -156,6 +157,21 @@ def __init__(self, database: DatabasePool, db_conn, hs):
self._event_fetch_list = []
self._event_fetch_ongoing = 0

# We define this sequence here so that it can be referenced from both
# the DataStore and PersistEventStore.
def get_chain_id_txn(txn):
txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
return txn.fetchone()[0]

self.event_chain_id_gen = build_sequence_generator(
db_conn,
database.engine,
get_chain_id_txn,
"event_auth_chain_id",
table="event_auth_chains",
id_column="chain_id",
)

def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == EventsStream.NAME:
self._stream_id_gen.advance(instance_name, token)
Expand Down
19 changes: 16 additions & 3 deletions synapse/storage/databases/main/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage.database import DatabasePool
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.databases.main.stats import StatsStore
from synapse.storage.types import Connection, Cursor
Expand Down Expand Up @@ -70,7 +70,12 @@ def _default_token_owner(self):


class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I remember you and @richvdh having a conversation about switching these from Connection to LoggingDatabaseConnection, but do not remember the OK. Are we OK making this change?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I think we are

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

he convinced me I think

hs: "HomeServer",
):
super().__init__(database, db_conn, hs)

self.config = hs.config
Expand All @@ -79,9 +84,12 @@ def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"
# call `find_max_generated_user_id_localpart` each time, which is
# expensive if there are many entries.
self._user_id_seq = build_sequence_generator(
db_conn,
database.engine,
find_max_generated_user_id_localpart,
"user_id_seq",
table=None,
id_column=None,
)

self._account_validity = hs.config.account_validity
Expand Down Expand Up @@ -1036,7 +1044,12 @@ async def update_access_token_last_validated(self, token_id: int) -> None:


class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)

self._clock = hs.get_clock()
Expand Down
10 changes: 6 additions & 4 deletions synapse/storage/databases/state/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,12 @@ def get_max_state_group_txn(txn: Cursor):
return txn.fetchone()[0]

self._state_group_seq_gen = build_sequence_generator(
self.database_engine, get_max_state_group_txn, "state_group_id_seq"
)
self._state_group_seq_gen.check_consistency(
db_conn, table="state_groups", id_column="id"
db_conn,
self.database_engine,
get_max_state_group_txn,
"state_group_id_seq",
table="state_groups",
id_column="id",
)

@cached(max_entries=10000, iterable=True)
Expand Down
24 changes: 22 additions & 2 deletions synapse/storage/util/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,9 +251,14 @@ def check_consistency(


def build_sequence_generator(
db_conn: "LoggingDatabaseConnection",
database_engine: BaseDatabaseEngine,
get_first_callback: GetFirstCallbackType,
sequence_name: str,
table: Optional[str],
id_column: Optional[str],
stream_name: Optional[str] = None,
positive: bool = True,
) -> SequenceGenerator:
"""Get the best impl of SequenceGenerator available

Expand All @@ -265,8 +270,23 @@ def build_sequence_generator(
get_first_callback: a callback which gets the next sequence ID. Used if
we're on sqlite.
sequence_name: the name of a postgres sequence to use.
table, id_column, stream_name, positive: If set then `check_consistency`
is called on the created sequence. See docstring for
`check_consistency` details.
"""
if isinstance(database_engine, PostgresEngine):
return PostgresSequenceGenerator(sequence_name)
seq = PostgresSequenceGenerator(sequence_name) # type: SequenceGenerator
else:
return LocalSequenceGenerator(get_first_callback)
seq = LocalSequenceGenerator(get_first_callback)

if table:
assert id_column
seq.check_consistency(
db_conn=db_conn,
table=table,
id_column=id_column,
stream_name=stream_name,
positive=positive,
)

return seq