From 32cbfd2acd28bcefb97c442ac8e3ee2c07401e19 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Fri, 15 Nov 2019 14:56:30 -0800 Subject: [PATCH] New Pool/ConnectionGroup implementation Refs #569 --- datasette/database.py | 91 +++++++++++++++++++++++++++++++++++++++++++ tests/test_pool.py | 65 +++++++++++++++++++++++++++++++ 2 files changed, 156 insertions(+) create mode 100644 tests/test_pool.py diff --git a/datasette/database.py b/datasette/database.py index 9a8ae4d434..4886a10eac 100644 --- a/datasette/database.py +++ b/datasette/database.py @@ -327,3 +327,94 @@ def __repr__(self): if tags: tags_str = " ({})".format(", ".join(tags)) return "".format(self.name, tags_str) + + +class ConnectionError(Exception): + pass + + +class ConnectionTimeoutError(ConnectionError): + pass + + +class Connection: + def __init__(self, name, connect_args=None, connect_kwargs=None): + self.name = name + self.connect_args = connect_args or tuple() + self.connect_kwargs = connect_kwargs or {} + self.lock = threading.Lock() + self._connection = None + + def connection(self): + if self._connection is None: + self._connection = sqlite3.connect( + *self.connect_args, **self.connect_kwargs + ) + return self._connection + + def __repr__(self): + return "{} {} ({})".format(self.name, self._connection, self.lock) + + +class ConnectionGroup: + timeout = 2 + + def __init__(self, name, connect_args, connect_kwargs=None, limit=3): + self.name = name + self.connections = [ + Connection(name, connect_args, connect_kwargs) for _ in range(limit) + ] + self.limit = limit + self._semaphore = threading.Semaphore(value=limit) + + @contextlib.contextmanager + def connection(self): + semaphore_aquired = False + reserved_connection = None + try: + semaphore_aquired = self._semaphore.acquire(timeout=self.timeout) + if not semaphore_aquired: + raise ConnectionTimeoutError( + "Timed out after {}s waiting for connection '{}'".format( + self.timeout, self.name + ) + ) + # Loop through connections attempting to aquire a lock + for connection in self.connections: + lock = connection.lock + if lock.acquire(False): + # We acquired the lock! use this one + reserved_connection = connection + break + else: + # If we get here, we failed to lock a connection even though + # the semaphore should have guaranteed it + raise ConnectionError( + "Failed to lock a connection despite the sempahore" + ) + # We should have a connection now - yield it, then clean up locks + yield reserved_connection + finally: + reserved_connection.lock.release() + if semaphore_aquired: + self._semaphore.release() + + +class Pool: + def __init__(self, databases=None, max_connections_per_database=3): + self.max_connections_per_database = max_connections_per_database + self.databases = {} + self.connection_groups = {} + for key, value in (databases or {}).items(): + self.add_database(key, value) + + def add_database(self, name, filepath): + self.databases[name] = filepath + self.connection_groups[name] = ConnectionGroup( + name, [filepath], limit=self.max_connections_per_database + ) + + @contextlib.contextmanager + def connection(self, name): + with self.connection_groups[name].connection() as conn: + yield conn diff --git a/tests/test_pool.py b/tests/test_pool.py new file mode 100644 index 0000000000..70407d3cb8 --- /dev/null +++ b/tests/test_pool.py @@ -0,0 +1,65 @@ +import threading +import time + +import pytest + +from datasette.database import Pool + + +def test_lock_connection(): + pool = Pool({"one": ":memory:"}) + with pool.connection("one") as conn: + assert conn.lock.locked() + assert not conn.lock.locked() + + +def test_connect_if_one_connection_is_locked(): + pool = Pool({"one": ":memory:"}) + connections = pool.connection_groups["one"].connections + assert 3 == len(connections) + # They should all start unlocked: + assert all(not c.lock.locked() for c in connections) + # Now lock one for the duration of this test + first_connection = connections[0] + try: + first_connection.lock.acquire() + # This should give us a different connection + with pool.connection("one") as conn: + assert conn is not first_connection + assert conn.lock.locked() + # There should be only one UNLOCKED connection now + assert 1 == len([c for c in connections if not c.lock.locked()]) + finally: + first_connection.lock.release() + # At this point, all connections should be unlocked + assert 3 == len([c for c in connections if not c.lock.locked()]) + + +def test_block_until_connection_is_released(): + # If all connections are already in use, block until one is released + pool = Pool({"one": ":memory:"}, max_connections_per_database=1) + connections = pool.connection_groups["one"].connections + assert 1 == len(connections) + + def block_connection(pool): + with pool.connection("one"): + time.sleep(0.05) + + t = threading.Thread(target=block_connection, args=[pool]) + t.start() + # Give thread time to grab the connection: + time.sleep(0.01) + # Thread should now have grabbed and locked a connection: + assert 1 == len([c for c in connections if c.lock.locked()]) + + start = time.time() + # Now we attempt to use the connection. This should block. + with pool.connection("one") as conn: + # At this point, over 0.02 seconds should have passed + assert (time.time() - start) > 0.02 + assert conn.lock.locked() + + # Ensure thread has run to completion before ending test: + t.join() + # Connections should all be unlocked at the end + assert all(not c.lock.locked() for c in connections)