Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New connection pooling #579

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions datasette/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,3 +327,94 @@ def __repr__(self):
if tags:
tags_str = " ({})".format(", ".join(tags))
return "<Database: {}{}>".format(self.name, tags_str)


class ConnectionError(Exception):
    """Base class for errors raised by the connection pool in this module.

    NOTE: this name is also a Python builtin; inside this module the
    definition below takes precedence over the builtin.
    """


class ConnectionTimeoutError(ConnectionError):
    """Raised when waiting for a pooled connection exceeds the timeout."""


class Connection:
    """A lazily opened SQLite connection bundled with its own lock.

    The underlying ``sqlite3`` connection is not created until
    ``connection()`` is first called; afterwards the same object is reused.
    """

    def __init__(self, name, connect_args=None, connect_kwargs=None):
        """Store the connect arguments and create an unlocked lock.

        ``connect_args`` / ``connect_kwargs`` are forwarded verbatim to
        ``sqlite3.connect``; falsy values fall back to ``()`` and ``{}``.
        """
        self.name = name
        self.connect_args = connect_args if connect_args else tuple()
        self.connect_kwargs = connect_kwargs if connect_kwargs else {}
        self.lock = threading.Lock()
        self._connection = None

    def connection(self):
        """Return the sqlite3 connection, opening it on first use."""
        if self._connection is not None:
            return self._connection
        opened = sqlite3.connect(*self.connect_args, **self.connect_kwargs)
        self._connection = opened
        return self._connection

    def __repr__(self):
        template = "{} {} ({})"
        return template.format(self.name, self._connection, self.lock)


class ConnectionGroup:
    """A fixed-size group of ``Connection`` objects sharing connect arguments.

    A semaphore initialised to ``limit`` bounds how many callers may hold a
    connection at the same time; each ``Connection`` additionally carries its
    own lock, which marks it as reserved while a caller is using it.
    """

    # Seconds to wait for the semaphore before raising ConnectionTimeoutError
    timeout = 2

    def __init__(self, name, connect_args, connect_kwargs=None, limit=3):
        """Create ``limit`` connections that all use the same connect arguments.

        Args:
            name: label used for the connections and in error messages.
            connect_args: positional arguments handed to each ``Connection``.
            connect_kwargs: keyword arguments handed to each ``Connection``.
            limit: number of connections (and semaphore permits) to create.
        """
        self.name = name
        self.connections = [
            Connection(name, connect_args, connect_kwargs) for _ in range(limit)
        ]
        self.limit = limit
        self._semaphore = threading.Semaphore(value=limit)

    @contextlib.contextmanager
    def connection(self):
        """Reserve one connection for the duration of a ``with`` block.

        Yields:
            The reserved ``Connection``, with its lock held.

        Raises:
            ConnectionTimeoutError: the semaphore could not be acquired
                within ``self.timeout`` seconds.
            ConnectionError: the semaphore was acquired but every
                connection's lock was already held.
        """
        # Nothing has been acquired yet if this fails, so there is nothing
        # to clean up - raise before entering the try/finally.
        if not self._semaphore.acquire(timeout=self.timeout):
            raise ConnectionTimeoutError(
                "Timed out after {}s waiting for connection '{}'".format(
                    self.timeout, self.name
                )
            )
        reserved_connection = None
        try:
            # Loop through connections attempting to acquire a lock
            for candidate in self.connections:
                if candidate.lock.acquire(False):
                    # We acquired the lock! use this one
                    reserved_connection = candidate
                    break
            else:
                # If we get here, we failed to lock a connection even though
                # the semaphore should have guaranteed it
                raise ConnectionError(
                    "Failed to lock a connection despite the sempahore"
                )
            # We have a connection now - yield it, then clean up locks
            yield reserved_connection
        finally:
            # Only release a lock we actually hold; without this guard a
            # failure above would surface as AttributeError on None and the
            # semaphore permit would never be returned.
            if reserved_connection is not None:
                reserved_connection.lock.release()
            self._semaphore.release()


class Pool:
    """Registry mapping database names to their ``ConnectionGroup``."""

    def __init__(self, databases=None, max_connections_per_database=3):
        """Set up the pool and register any initial databases.

        ``databases`` maps a name to the file path passed on to
        ``add_database``; a falsy value registers nothing.
        """
        self.max_connections_per_database = max_connections_per_database
        self.databases = {}
        self.connection_groups = {}
        initial = databases if databases else {}
        for db_name, db_path in initial.items():
            self.add_database(db_name, db_path)

    def add_database(self, name, filepath):
        """Record ``filepath`` under ``name`` and build its connection group."""
        self.databases[name] = filepath
        group = ConnectionGroup(
            name, [filepath], limit=self.max_connections_per_database
        )
        self.connection_groups[name] = group

    @contextlib.contextmanager
    def connection(self, name):
        """Yield a connection reserved from the group registered as ``name``."""
        group = self.connection_groups[name]
        with group.connection() as reserved:
            yield reserved
65 changes: 65 additions & 0 deletions tests/test_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import threading
import time

import pytest

from datasette.database import Pool


def test_lock_connection():
    """The connection's lock is held inside the ``with`` block and freed after."""
    pool = Pool({"one": ":memory:"})
    with pool.connection("one") as reserved:
        held_inside = reserved.lock.locked()
        assert held_inside
    held_after = reserved.lock.locked()
    assert not held_after


def test_connect_if_one_connection_is_locked():
    """With one connection locked, the pool hands out a different one."""
    pool = Pool({"one": ":memory:"})
    connections = pool.connection_groups["one"].connections
    assert len(connections) == 3
    # Nothing is locked to begin with
    assert not any(c.lock.locked() for c in connections)
    # Hold the first connection's lock while we ask the pool for a connection
    first_connection = connections[0]
    with first_connection.lock:
        with pool.connection("one") as conn:
            # The pool must have skipped the connection we are holding
            assert conn is not first_connection
            assert conn.lock.locked()
            # Two of three are locked, so exactly one remains free
            unlocked = [c for c in connections if not c.lock.locked()]
            assert len(unlocked) == 1
    # Everything has been released again
    free_count = sum(1 for c in connections if not c.lock.locked())
    assert free_count == 3


def test_block_until_connection_is_released():
    """If all connections are already in use, block until one is released."""
    pool = Pool({"one": ":memory:"}, max_connections_per_database=1)
    connections = pool.connection_groups["one"].connections
    assert 1 == len(connections)

    # Set by the worker once it holds the connection, so the main thread
    # does not have to guess how long the worker needs to get there.
    connection_held = threading.Event()

    def block_connection(pool):
        with pool.connection("one"):
            connection_held.set()
            time.sleep(0.05)

    t = threading.Thread(target=block_connection, args=[pool])
    t.start()
    try:
        # Wait for the thread to grab the connection (bounded, so a broken
        # pool fails the test instead of hanging it)
        assert connection_held.wait(timeout=1), "worker never got a connection"
        # Thread should now have grabbed and locked a connection:
        assert 1 == len([c for c in connections if c.lock.locked()])

        # monotonic() is immune to wall-clock adjustments during the test
        start = time.monotonic()
        # Now we attempt to use the connection. This should block.
        with pool.connection("one") as conn:
            # At this point, over 0.02 seconds should have passed
            assert (time.monotonic() - start) > 0.02
            assert conn.lock.locked()
    finally:
        # Ensure thread has run to completion before ending test, even when
        # one of the assertions above fails
        t.join()
    # Connections should all be unlocked at the end
    assert all(not c.lock.locked() for c in connections)