From c8139e02039432824a5b2cf07fda3f03f0cbf027 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 8 Apr 2022 17:00:28 +0100 Subject: [PATCH] No untyped defs in `synapse_port_db` This was by far the most painful. I'm happy to break this up into smaller pieces for review if it's not manageable as-is. --- synapse/_scripts/synapse_port_db.py | 223 ++++++++++++++++++---------- 1 file changed, 148 insertions(+), 75 deletions(-) diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index 123eaae5c56b..5099ef326c38 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +:!/usr/bin/env python # Copyright 2015, 2016 OpenMarket Ltd # Copyright 2018 New Vector Ltd # Copyright 2019 The Matrix.org Foundation C.I.C. @@ -22,7 +22,23 @@ import time import traceback from types import TracebackType -from typing import Dict, Iterable, Optional, Set, Tuple, Type, cast +from typing import ( + Any, + Awaitable, + Callable, + Dict, + Generator, + Iterable, + List, + NoReturn, + Optional, + Set, + Tuple, + Type, + TypedDict, + TypeVar, + cast, +) import yaml from matrix_common.versionstring import get_distribution_version_string @@ -36,7 +52,7 @@ make_deferred_yieldable, run_in_background, ) -from synapse.storage.database import DatabasePool, make_conn +from synapse.storage.database import DatabasePool, LoggingTransaction, make_conn from synapse.storage.databases.main import PushRuleStore from synapse.storage.databases.main.account_data import AccountDataWorkerStore from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateStore @@ -173,6 +189,8 @@ Tuple[Type[BaseException], BaseException, TracebackType] ] = None +R = TypeVar("R") + class Store( ClientIpBackgroundUpdateStore, @@ -195,17 +213,19 @@ class Store( PresenceBackgroundUpdateStore, GroupServerWorkerStore, ): - def execute(self, f, *args, **kwargs): + def execute(self, f: Callable[..., R], *args: Any, 
**kwargs: Any) -> Awaitable[R]: return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs) - def execute_sql(self, sql, *args): - def r(txn): + def execute_sql(self, sql: str, *args: object) -> Awaitable[List[Tuple]]: + def r(txn: LoggingTransaction) -> List[Tuple]: txn.execute(sql, args) return txn.fetchall() return self.db_pool.runInteraction("execute_sql", r) - def insert_many_txn(self, txn, table, headers, rows): + def insert_many_txn( + self, txn: LoggingTransaction, table: str, headers: List[str], rows: List[Tuple] + ) -> None: sql = "INSERT INTO %s (%s) VALUES (%s)" % ( table, ", ".join(k for k in headers), @@ -218,14 +238,15 @@ def insert_many_txn(self, txn, table, headers, rows): logger.exception("Failed to insert: %s", table) raise - def set_room_is_public(self, room_id, is_public): + # Note: the parent method is an `async def`. + def set_room_is_public(self, room_id: str, is_public: bool) -> NoReturn: raise Exception( "Attempt to set room_is_public during port_db: database not empty?" 
) class MockHomeserver: - def __init__(self, config): + def __init__(self, config: HomeServerConfig): self.clock = Clock(reactor) self.config = config self.hostname = config.server.server_name @@ -233,24 +254,30 @@ def __init__(self, config): "matrix-synapse" ) - def get_clock(self): + def get_clock(self) -> Clock: return self.clock - def get_reactor(self): + def get_reactor(self) -> ISynapseReactor: return reactor - def get_instance_name(self): + def get_instance_name(self) -> str: return "master" class Porter: - def __init__(self, sqlite_config, progress, batch_size, hs_config): + def __init__( + self, + sqlite_config: Dict[str, Any], + progress: "Progress", + batch_size: int, + hs_config: HomeServerConfig, + ): self.sqlite_config = sqlite_config self.progress = progress self.batch_size = batch_size self.hs_config = hs_config - async def setup_table(self, table): + async def setup_table(self, table: str) -> Tuple[str, int, int, int, int]: if table in APPEND_ONLY_TABLES: # It's safe to just carry on inserting. row = await self.postgres_store.db_pool.simple_select_one( @@ -292,7 +319,7 @@ async def setup_table(self, table): ) else: - def delete_all(txn): + def delete_all(txn: LoggingTransaction) -> None: txn.execute( "DELETE FROM port_from_sqlite3 WHERE table_name = %s", (table,) ) @@ -317,7 +344,7 @@ def delete_all(txn): async def get_table_constraints(self) -> Dict[str, Set[str]]: """Returns a map of tables that have foreign key constraints to tables they depend on.""" - def _get_constraints(txn): + def _get_constraints(txn: LoggingTransaction) -> Dict[str, Set[str]]: # We can pull the information about foreign key constraints out from # the postgres schema tables. 
sql = """ @@ -343,8 +370,13 @@ def _get_constraints(txn): ) async def handle_table( - self, table, postgres_size, table_size, forward_chunk, backward_chunk - ): + self, + table: str, + postgres_size: int, + table_size: int, + forward_chunk: int, + backward_chunk: int, + ) -> None: logger.info( "Table %s: %i/%i (rows %i-%i) already ported", table, @@ -391,7 +423,9 @@ async def handle_table( while True: - def r(txn): + def r( + txn: LoggingTransaction, + ) -> Tuple[Optional[List[str]], List[Tuple], List[Tuple]]: forward_rows = [] backward_rows = [] if do_forward[0]: @@ -416,6 +450,7 @@ def r(txn): headers, frows, brows = await self.sqlite_store.db_pool.runInteraction( "select", r ) + assert headers is not None if frows or brows: if frows: @@ -426,7 +461,8 @@ def r(txn): rows = frows + brows rows = self._convert_rows(table, headers, rows) - def insert(txn): + def insert(txn: LoggingTransaction) -> None: + assert headers is not None self.postgres_store.insert_many_txn(txn, table, headers[1:], rows) self.postgres_store.db_pool.simple_update_one_txn( @@ -448,8 +484,12 @@ def insert(txn): return async def handle_search_table( - self, postgres_size, table_size, forward_chunk, backward_chunk - ): + self, + postgres_size: int, + table_size: int, + forward_chunk: int, + backward_chunk: int, + ) -> None: select = ( "SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering" " FROM event_search as es" @@ -460,7 +500,7 @@ async def handle_search_table( while True: - def r(txn): + def r(txn: LoggingTransaction) -> Tuple[List[str], List[Tuple]]: txn.execute(select, (forward_chunk, self.batch_size)) rows = txn.fetchall() headers = [column[0] for column in txn.description] @@ -474,7 +514,7 @@ def r(txn): # We have to treat event_search differently since it has a # different structure in the two different databases. 
- def insert(txn): + def insert(txn: LoggingTransaction) -> None: sql = ( "INSERT INTO event_search (event_id, room_id, key," " sender, vector, origin_server_ts, stream_ordering)" @@ -528,7 +568,7 @@ def build_db_store( self, db_config: DatabaseConnectionConfig, allow_outdated_version: bool = False, - ): + ) -> Store: """Builds and returns a database store using the provided configuration. Args: @@ -556,7 +596,7 @@ def build_db_store( return store - async def run_background_updates_on_postgres(self): + async def run_background_updates_on_postgres(self) -> None: # Manually apply all background updates on the PostgreSQL database. postgres_ready = ( await self.postgres_store.db_pool.updates.has_completed_background_updates() @@ -568,12 +608,12 @@ async def run_background_updates_on_postgres(self): self.progress.set_state("Running background updates on PostgreSQL") while not postgres_ready: - await self.postgres_store.db_pool.updates.do_next_background_update(100) + await self.postgres_store.db_pool.updates.do_next_background_update(True) postgres_ready = await ( self.postgres_store.db_pool.updates.has_completed_background_updates() ) - async def run(self): + async def run(self) -> None: """Ports the SQLite database to a PostgreSQL database. When a fatal error is met, its message is assigned to the global "end_error" @@ -609,7 +649,7 @@ async def run(self): self.progress.set_state("Creating port tables") - def create_port_table(txn): + def create_port_table(txn: LoggingTransaction) -> None: txn.execute( "CREATE TABLE IF NOT EXISTS port_from_sqlite3 (" " table_name varchar(100) NOT NULL UNIQUE," @@ -622,7 +662,7 @@ def create_port_table(txn): # We want people to be able to rerun this script from an old port # so that they can pick up any missing events that were not # ported across. 
- def alter_table(txn): + def alter_table(txn: LoggingTransaction) -> None: txn.execute( "ALTER TABLE IF EXISTS port_from_sqlite3" " RENAME rowid TO forward_rowid" @@ -742,7 +782,9 @@ def alter_table(txn): finally: reactor.stop() - def _convert_rows(self, table, headers, rows): + def _convert_rows( + self, table: str, headers: List[str], rows: List[Tuple] + ) -> List[Tuple]: bool_col_names = BOOLEAN_COLUMNS.get(table, []) bool_cols = [i for i, h in enumerate(headers) if h in bool_col_names] @@ -750,7 +792,7 @@ def _convert_rows(self, table, headers, rows): class BadValueException(Exception): pass - def conv(j, col): + def conv(j: int, col: object) -> object: if j in bool_cols: return bool(col) if isinstance(col, bytes): @@ -776,7 +818,7 @@ def conv(j, col): return outrows - async def _setup_sent_transactions(self): + async def _setup_sent_transactions(self) -> Tuple[int, int, int]: # Only save things from the last day yesterday = int(time.time() * 1000) - 86400000 @@ -788,10 +830,10 @@ async def _setup_sent_transactions(self): ")" ) - def r(txn): + def r(txn: LoggingTransaction) -> Tuple[List[str], List[Tuple]]: txn.execute(select) rows = txn.fetchall() - headers = [column[0] for column in txn.description] + headers: List[str] = [column[0] for column in txn.description] ts_ind = headers.index("ts") @@ -805,7 +847,7 @@ def r(txn): if inserted_rows: max_inserted_rowid = max(r[0] for r in rows) - def insert(txn): + def insert(txn: LoggingTransaction) -> None: self.postgres_store.insert_many_txn( txn, "sent_transactions", headers[1:], rows ) @@ -814,7 +856,7 @@ def insert(txn): else: max_inserted_rowid = 0 - def get_start_id(txn): + def get_start_id(txn: LoggingTransaction) -> int: txn.execute( "SELECT rowid FROM sent_transactions WHERE ts >= ?" 
" ORDER BY rowid ASC LIMIT 1", @@ -839,12 +881,13 @@ def get_start_id(txn): }, ) - def get_sent_table_size(txn): + def get_sent_table_size(txn: LoggingTransaction) -> int: txn.execute( "SELECT count(*) FROM sent_transactions" " WHERE ts >= ?", (yesterday,) ) - (size,) = txn.fetchone() - return int(size) + result = txn.fetchone() + assert result is not None + return int(result[0]) remaining_count = await self.sqlite_store.execute(get_sent_table_size) @@ -852,25 +895,35 @@ def get_sent_table_size(txn): return next_chunk, inserted_rows, total_count - async def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk): - frows = await self.sqlite_store.execute_sql( - "SELECT count(*) FROM %s WHERE rowid >= ?" % (table,), forward_chunk + async def _get_remaining_count_to_port( + self, table: str, forward_chunk: int, backward_chunk: int + ) -> int: + frows = cast( + List[Tuple[int]], + await self.sqlite_store.execute_sql( + "SELECT count(*) FROM %s WHERE rowid >= ?" % (table,), forward_chunk + ), ) - brows = await self.sqlite_store.execute_sql( - "SELECT count(*) FROM %s WHERE rowid <= ?" % (table,), backward_chunk + brows = cast( + List[Tuple[int]], + await self.sqlite_store.execute_sql( + "SELECT count(*) FROM %s WHERE rowid <= ?" 
% (table,), backward_chunk + ), ) return frows[0][0] + brows[0][0] - async def _get_already_ported_count(self, table): + async def _get_already_ported_count(self, table: str) -> int: rows = await self.postgres_store.execute_sql( "SELECT count(*) FROM %s" % (table,) ) return rows[0][0] - async def _get_total_count_to_port(self, table, forward_chunk, backward_chunk): + async def _get_total_count_to_port( + self, table: str, forward_chunk: int, backward_chunk: int + ) -> Tuple[int, int]: remaining, done = await make_deferred_yieldable( defer.gatherResults( [ @@ -891,14 +944,17 @@ async def _get_total_count_to_port(self, table, forward_chunk, backward_chunk): return done, remaining + done async def _setup_state_group_id_seq(self) -> None: - curr_id = await self.sqlite_store.db_pool.simple_select_one_onecol( + curr_id: Optional[ + int + ] = await self.sqlite_store.db_pool.simple_select_one_onecol( table="state_groups", keyvalues={}, retcol="MAX(id)", allow_none=True ) if not curr_id: return - def r(txn): + def r(txn: LoggingTransaction) -> None: + assert curr_id is not None next_id = curr_id + 1 txn.execute("ALTER SEQUENCE state_group_id_seq RESTART WITH %s", (next_id,)) @@ -909,7 +965,7 @@ async def _setup_user_id_seq(self) -> None: "setup_user_id_seq", find_max_generated_user_id_localpart ) - def r(txn): + def r(txn: LoggingTransaction) -> None: next_id = curr_id + 1 txn.execute("ALTER SEQUENCE user_id_seq RESTART WITH %s", (next_id,)) @@ -931,7 +987,7 @@ async def _setup_events_stream_seqs(self) -> None: allow_none=True, ) - def _setup_events_stream_seqs_set_pos(txn): + def _setup_events_stream_seqs_set_pos(txn: LoggingTransaction) -> None: if curr_forward_id: txn.execute( "ALTER SEQUENCE events_stream_seq RESTART WITH %s", @@ -955,17 +1011,20 @@ async def _setup_sequence( """Set a sequence to the correct value.""" current_stream_ids = [] for stream_id_table in stream_id_tables: - max_stream_id = await self.sqlite_store.db_pool.simple_select_one_onecol( - 
table=stream_id_table, - keyvalues={}, - retcol="COALESCE(MAX(stream_id), 1)", - allow_none=True, + max_stream_id = cast( + int, + await self.sqlite_store.db_pool.simple_select_one_onecol( + table=stream_id_table, + keyvalues={}, + retcol="COALESCE(MAX(stream_id), 1)", + allow_none=True, + ), ) current_stream_ids.append(max_stream_id) next_id = max(current_stream_ids) + 1 - def r(txn): + def r(txn: LoggingTransaction) -> None: sql = "ALTER SEQUENCE %s RESTART WITH" % (sequence_name,) txn.execute(sql + " %s", (next_id,)) @@ -974,14 +1033,18 @@ def r(txn): ) async def _setup_auth_chain_sequence(self) -> None: - curr_chain_id = await self.sqlite_store.db_pool.simple_select_one_onecol( + curr_chain_id: Optional[ + int + ] = await self.sqlite_store.db_pool.simple_select_one_onecol( table="event_auth_chains", keyvalues={}, retcol="MAX(chain_id)", allow_none=True, ) - def r(txn): + def r(txn: LoggingTransaction) -> None: + # Presumably there is at least one row in event_auth_chains. + assert curr_chain_id is not None txn.execute( "ALTER SEQUENCE event_auth_chain_id RESTART WITH %s", (curr_chain_id + 1,), @@ -999,15 +1062,22 @@ def r(txn): ############################################## -class Progress(object): +class TableProgress(TypedDict): + start: int + num_done: int + total: int + perc: int + + +class Progress: """Used to report progress of the port""" - def __init__(self): - self.tables = {} + def __init__(self) -> None: + self.tables: Dict[str, TableProgress] = {} self.start_time = int(time.time()) - def add_table(self, table, cur, size): + def add_table(self, table: str, cur: int, size: int) -> None: self.tables[table] = { "start": cur, "num_done": cur, @@ -1015,19 +1085,22 @@ def add_table(self, table, cur, size): "perc": int(cur * 100 / size), } - def update(self, table, num_done): + def update(self, table: str, num_done: int) -> None: data = self.tables[table] data["num_done"] = num_done data["perc"] = int(num_done * 100 / data["total"]) - def done(self): + def 
done(self) -> None: + pass + + def set_state(self, state: str) -> None: pass class CursesProgress(Progress): """Reports progress to a curses window""" - def __init__(self, stdscr): + def __init__(self, stdscr: curses.window): self.stdscr = stdscr curses.use_default_colors() @@ -1045,7 +1118,7 @@ def __init__(self, stdscr): super(CursesProgress, self).__init__() - def update(self, table, num_done): + def update(self, table: str, num_done: int) -> None: super(CursesProgress, self).update(table, num_done) self.total_processed = 0 @@ -1056,7 +1129,7 @@ def update(self, table, num_done): self.render() - def render(self, force=False): + def render(self, force: bool = False) -> None: now = time.time() if not force and now - self.last_update < 0.2: @@ -1128,12 +1201,12 @@ def render(self, force=False): self.stdscr.refresh() self.last_update = time.time() - def done(self): + def done(self) -> None: self.finished = True self.render(True) self.stdscr.getch() - def set_state(self, state): + def set_state(self, state: str) -> None: self.stdscr.clear() self.stdscr.addstr(0, 0, state + "...", curses.A_BOLD) self.stdscr.refresh() @@ -1142,7 +1215,7 @@ def set_state(self, state): class TerminalProgress(Progress): """Just prints progress to the terminal""" - def update(self, table, num_done): + def update(self, table: str, num_done: int) -> None: super(TerminalProgress, self).update(table, num_done) data = self.tables[table] @@ -1151,7 +1224,7 @@ def update(self, table, num_done): "%s: %d%% (%d/%d)" % (table, data["perc"], data["num_done"], data["total"]) ) - def set_state(self, state): + def set_state(self, state: str) -> None: print(state + "...") @@ -1159,7 +1232,7 @@ def set_state(self, state): ############################################## -def main(): +def main() -> None: parser = argparse.ArgumentParser( description="A script to port an existing synapse SQLite database to" " a new PostgreSQL database." 
@@ -1225,7 +1298,7 @@ def main(): config = HomeServerConfig() config.parse_config_dict(hs_config, "", "") - def start(stdscr=None): + def start(stdscr: Optional[curses.window] = None) -> None: progress: Progress if stdscr: progress = CursesProgress(stdscr) @@ -1240,7 +1313,7 @@ def start(stdscr=None): ) @defer.inlineCallbacks - def run(): + def run() -> Generator["defer.Deferred[Any]", Any, None]: with LoggingContext("synapse_port_db_run"): yield defer.ensureDeferred(porter.run())