Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

A unified, robust and bug-free connection management interface for the ORM #1001

Merged
merged 31 commits into from
Mar 20, 2022
Merged
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
c2a398b
- Initial working commit for refactored connection management.
blazing-gig Dec 5, 2021
c5c7d31
- Refactored test_init.py to conform to the new connection management…
blazing-gig Dec 5, 2021
1fd9dac
- Fixed linting errors (flake8 and mypy)
blazing-gig Dec 5, 2021
faa4d87
- Modified comment
blazing-gig Dec 5, 2021
79b187a
- Code cleanup
blazing-gig Dec 6, 2021
8184432
- Fixed style issues
blazing-gig Dec 7, 2021
d3e7403
- Fixed mysql breaking tests
blazing-gig Dec 19, 2021
212c8f7
Merge remote-tracking branch 'origin/develop' into feat/connection_mgmt
aditya-n-invcr Dec 20, 2021
adacec3
Merge remote-tracking branch 'origin/develop' into feat/connection_mgmt
aditya-n-invcr Jan 11, 2022
18912ef
- Merged latest head from develop and fixed breaking tests
blazing-gig Jan 13, 2022
9ed0e71
- Fixed style and lint issues
blazing-gig Jan 13, 2022
69d947e
- Fixed code quality issues (codacy)
blazing-gig Jan 13, 2022
3b3b2e5
- Refactored core and test files to use the new connection interface
blazing-gig Jan 13, 2022
a7f47ba
- Fixed style and lint errors
blazing-gig Jan 13, 2022
e8e529b
- Added unit tests for the new connection interface
blazing-gig Jan 18, 2022
ab7051e
- Added docs for the new connection interface
blazing-gig Jan 18, 2022
d99aeda
- Fixed codacy check
blazing-gig Jan 18, 2022
9adabf7
- Fixed codacy check
blazing-gig Jan 18, 2022
1edd813
- Fixed codacy check
blazing-gig Jan 18, 2022
d7e7048
- Fixed codacy check
blazing-gig Jan 18, 2022
c04d1eb
- Fixed codacy check
blazing-gig Jan 18, 2022
32f87e4
- Refactored Tortoise.close_connections to use the connections.close_…
blazing-gig Jan 19, 2022
fbfaff8
- Updated docs
blazing-gig Jan 19, 2022
0468d81
Merge remote-tracking branch 'origin/develop' into feat/connection_mgmt
aditya-n-invcr Feb 16, 2022
522945e
Merge remote-tracking branch 'origin/develop' into feat/connection_mgmt
aditya-n-invcr Feb 25, 2022
58d8843
- Removed current_transaction_map occurrences in code
blazing-gig Mar 5, 2022
092448b
Merge remote-tracking branch 'origin/develop' into feat/connection_mgmt
aditya-n-invcr Mar 5, 2022
ee0a189
- Fixed breaking tests due to merge
blazing-gig Mar 16, 2022
8711b4f
- Fixed lint errors
blazing-gig Mar 16, 2022
7518a02
- Updated CHANGELOG.rst
blazing-gig Mar 18, 2022
eb6fb02
Merge remote-tracking branch 'origin/develop' into feat/connection_mgmt
aditya-n-invcr Mar 18, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions docs/connections.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
.. _connections:

===========
Connections
===========

This document describes how to access the underlying connection object (:ref:`BaseDBAsyncClient<base_db_client>`) for the aliases defined
as part of the DB config passed to the :meth:`Tortoise.init<tortoise.Tortoise.init>` call.

Below is a simple code snippet which shows how the interface can be accessed:

.. code-block:: python3

# connections is a singleton instance of the ConnectionHandler class and serves as the
# entrypoint to access all connection management APIs.
from tortoise import connections


# Assume that this is the Tortoise configuration used
await Tortoise.init(
{
"connections": {
"default": {
"engine": "tortoise.backends.sqlite",
"credentials": {"file_path": "example.sqlite3"},
}
},
"apps": {
"events": {"models": ["__main__"], "default_connection": "default"}
},
}
)

conn: BaseDBAsyncClient = connections.get("default")
try:
await conn.execute_query('SELECT * FROM "event"')
except OperationalError:
print("Expected it to fail")

.. important::
The :ref:`tortoise.connection.ConnectionHandler<connection_handler>` class has been implemented with the singleton
pattern in mind and so when the ORM initializes, a singleton instance of this class
``tortoise.connection.connections`` is created automatically and lives in memory up until the lifetime of the app.
Any attempt to modify or override its behaviour at runtime is risky and not recommended.


Please refer to :ref:`this example<example_two_databases>` for a detailed demonstration of how this API can be used
in practice.


API Reference
===========

.. _connection_handler:

.. automodule:: tortoise.connection
:members:
:undoc-members:
1 change: 1 addition & 0 deletions docs/databases.rst
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ handle complex objects.
}
)

.. _base_db_client:

Base DB client
==============
Expand Down
1 change: 1 addition & 0 deletions docs/reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Reference
functions
expressions
transactions
connections
exceptions
signals
migration
Expand Down
4 changes: 2 additions & 2 deletions examples/manual_sql.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
This example demonstrates executing manual SQL queries
"""
from tortoise import Tortoise, fields, run_async
from tortoise import Tortoise, connections, fields, run_async
from tortoise.models import Model
from tortoise.transactions import in_transaction

Expand All @@ -17,7 +17,7 @@ async def run():
await Tortoise.generate_schemas()

# Need to get a connection. Unless explicitly specified, the name should be 'default'
conn = Tortoise.get_connection("default")
conn = connections.get("default")

# Now we can execute queries in the normal autocommit mode
await conn.execute_query("INSERT INTO event (name) VALUES ('Foo')")
Expand Down
8 changes: 4 additions & 4 deletions examples/schema_create.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
This example demonstrates SQL Schema generation for each DB type supported.
"""
from tortoise import Tortoise, fields, run_async
from tortoise import Tortoise, connections, fields, run_async
from tortoise.models import Model
from tortoise.utils import get_schema_sql

Expand Down Expand Up @@ -49,19 +49,19 @@ class Meta:
async def run():
print("SQLite:\n")
await Tortoise.init(db_url="sqlite://:memory:", modules={"models": ["__main__"]})
sql = get_schema_sql(Tortoise.get_connection("default"), safe=False)
sql = get_schema_sql(connections.get("default"), safe=False)
print(sql)

print("\n\nMySQL:\n")
await Tortoise.init(db_url="mysql://root:@127.0.0.1:3306/", modules={"models": ["__main__"]})
sql = get_schema_sql(Tortoise.get_connection("default"), safe=False)
sql = get_schema_sql(connections.get("default"), safe=False)
print(sql)

print("\n\nPostgreSQL:\n")
await Tortoise.init(
db_url="postgres://postgres:@127.0.0.1:5432/", modules={"models": ["__main__"]}
)
sql = get_schema_sql(Tortoise.get_connection("default"), safe=False)
sql = get_schema_sql(connections.get("default"), safe=False)
print(sql)


Expand Down
6 changes: 3 additions & 3 deletions examples/two_databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
Key notes of this example is using db_route for Tortoise init
and explicitly declaring model apps in class Meta
"""
from tortoise import Tortoise, fields, run_async
from tortoise import Tortoise, connections, fields, run_async
from tortoise.exceptions import OperationalError
from tortoise.models import Model

Expand Down Expand Up @@ -73,8 +73,8 @@ async def run():
}
)
await Tortoise.generate_schemas()
client = Tortoise.get_connection("first")
second_client = Tortoise.get_connection("second")
client = connections.get("first")
second_client = connections.get("second")

tournament = await Tournament.create(name="Tournament")
await Event(name="Event", tournament_id=tournament.id).save()
Expand Down
4 changes: 2 additions & 2 deletions tests/backends/test_capabilities.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from tortoise import Tortoise
from tortoise import connections
from tortoise.contrib import test


Expand All @@ -7,7 +7,7 @@ class TestCapabilities(test.TestCase):

async def asyncSetUp(self) -> None:
await super(TestCapabilities, self).asyncSetUp()
self.db = Tortoise.get_connection("models")
self.db = connections.get("models")
self.caps = self.db.capabilities

def test_str(self):
Expand Down
35 changes: 23 additions & 12 deletions tests/backends/test_connection_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,22 @@

import asyncpg

from tortoise import Tortoise
from tortoise import connections
from tortoise.contrib import test


class TestConnectionParams(test.TestCase):
class TestConnectionParams(test.SimpleTestCase):
async def asyncSetUp(self) -> None:
pass
await super().asyncSetUp()

async def asyncTearDown(self) -> None:
pass
await super().asyncTearDown()

async def test_mysql_connection_params(self):
with patch("asyncmy.create_pool", new=AsyncMock()) as mysql_connect:
await Tortoise._init_connections(
with patch(
"tortoise.backends.mysql.client.mysql.create_pool", new=AsyncMock()
) as mysql_connect:
await connections._init(
{
"models": {
"engine": "tortoise.backends.mysql",
Expand All @@ -32,6 +34,7 @@ async def test_mysql_connection_params(self):
},
False,
)
await connections.get("models").create_connection(with_db=True)

mysql_connect.assert_awaited_once_with( # nosec
autocommit=True,
Expand All @@ -49,8 +52,10 @@ async def test_mysql_connection_params(self):

async def test_asyncpg_connection_params(self):
try:
with patch("asyncpg.create_pool", new=AsyncMock()) as asyncpg_connect:
await Tortoise._init_connections(
with patch(
"tortoise.backends.asyncpg.client.asyncpg.create_pool", new=AsyncMock()
) as asyncpg_connect:
await connections._init(
{
"models": {
"engine": "tortoise.backends.asyncpg",
Expand All @@ -67,6 +72,7 @@ async def test_asyncpg_connection_params(self):
},
False,
)
await connections.get("models").create_connection(with_db=True)

asyncpg_connect.assert_awaited_once_with( # nosec
None,
Expand All @@ -88,8 +94,12 @@ async def test_asyncpg_connection_params(self):

async def test_psycopg_connection_params(self):
try:
with patch("psycopg_pool.AsyncConnectionPool.open", new=AsyncMock()) as psycopg_connect:
await Tortoise._init_connections(
with patch(
"tortoise.backends.psycopg.client.PsycopgClient.create_pool", new=AsyncMock()
) as patched_create_pool:
mocked_pool = AsyncMock()
patched_create_pool.return_value = mocked_pool
await connections._init(
{
"models": {
"engine": "tortoise.backends.psycopg",
Expand All @@ -106,8 +116,9 @@ async def test_psycopg_connection_params(self):
},
False,
)

psycopg_connect.assert_awaited_once_with( # nosec
await connections.get("models").create_connection(with_db=True)
patched_create_pool.assert_awaited_once()
mocked_pool.open.assert_awaited_once_with( # nosec
wait=True,
timeout=1,
)
Expand Down
5 changes: 3 additions & 2 deletions tests/backends/test_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,17 @@ async def asyncSetUp(self):
async def asyncTearDown(self) -> None:
if Tortoise._inited:
await Tortoise._drop_databases()
await super().asyncTearDown()

async def test_bad_charset(self):
self.db_config["connections"]["models"]["credentials"]["charset"] = "terrible"
with self.assertRaisesRegex(ConnectionError, "Unknown charset"):
await Tortoise.init(self.db_config)
await Tortoise.init(self.db_config, _create_db=True)

async def test_ssl_true(self):
self.db_config["connections"]["models"]["credentials"]["ssl"] = True
try:
await Tortoise.init(self.db_config)
await Tortoise.init(self.db_config, _create_db=True)
except (ConnectionError, ssl.SSLError):
pass
else:
Expand Down
13 changes: 7 additions & 6 deletions tests/backends/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import ssl

from tests.testmodels import Tournament
from tortoise import Tortoise
from tortoise import Tortoise, connections
from tortoise.contrib import test
from tortoise.exceptions import OperationalError

Expand All @@ -29,6 +29,7 @@ def is_asyncpg(self) -> bool:
async def asyncTearDown(self) -> None:
if Tortoise._inited:
await Tortoise._drop_databases()
await super().asyncTearDown()

async def test_schema(self):
if self.is_asyncpg:
Expand All @@ -42,20 +43,20 @@ async def test_schema(self):
with self.assertRaises(InvalidSchemaNameError):
await Tortoise.generate_schemas()

conn = Tortoise.get_connection("models")
conn = connections.get("models")
await conn.execute_script("CREATE SCHEMA mytestschema;")
await Tortoise.generate_schemas()

tournament = await Tournament.create(name="Test")
await Tortoise.close_connections()
await connections.close_all()

del self.db_config["connections"]["models"]["credentials"]["schema"]
await Tortoise.init(self.db_config)

with self.assertRaises(OperationalError):
await Tournament.filter(name="Test").first()

conn = Tortoise.get_connection("models")
conn = connections.get("models")
_, res = await conn.execute_query(
"SELECT id, name FROM mytestschema.tournament WHERE name='Test' LIMIT 1"
)
Expand All @@ -67,7 +68,7 @@ async def test_schema(self):
async def test_ssl_true(self):
self.db_config["connections"]["models"]["credentials"]["ssl"] = True
try:
await Tortoise.init(self.db_config)
await Tortoise.init(self.db_config, _create_db=True)
except (ConnectionError, ssl.SSLError):
pass
else:
Expand All @@ -91,7 +92,7 @@ async def test_application_name(self):
] = "mytest_application"
await Tortoise.init(self.db_config, _create_db=True)

conn = Tortoise.get_connection("models")
conn = connections.get("models")
_, res = await conn.execute_query(
"SELECT application_name FROM pg_stat_activity WHERE pid = pg_backend_pid()"
)
Expand Down
10 changes: 5 additions & 5 deletions tests/backends/test_reconnect.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from tests.testmodels import Tournament
from tortoise import Tortoise
from tortoise import connections
from tortoise.contrib import test
from tortoise.transactions import in_transaction

Expand All @@ -9,11 +9,11 @@ class TestReconnect(test.IsolatedTestCase):
async def test_reconnect(self):
await Tournament.create(name="1")

await Tortoise._connections["models"]._expire_connections()
await connections.get("models")._expire_connections()

await Tournament.create(name="2")

await Tortoise._connections["models"]._expire_connections()
await connections.get("models")._expire_connections()

await Tournament.create(name="3")

Expand All @@ -26,12 +26,12 @@ async def test_reconnect_transaction_start(self):
async with in_transaction():
await Tournament.create(name="1")

await Tortoise._connections["models"]._expire_connections()
await connections.get("models")._expire_connections()

async with in_transaction():
await Tournament.create(name="2")

await Tortoise._connections["models"]._expire_connections()
await connections.get("models")._expire_connections()

async with in_transaction():
self.assertEqual([f"{a.id}:{a.name}" for a in await Tournament.all()], ["1:1", "2:2"])
4 changes: 2 additions & 2 deletions tests/contrib/test_functions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from tests.testmodels import IntFields
from tortoise import Tortoise
from tortoise import connections
from tortoise.contrib import test
from tortoise.contrib.mysql.functions import Rand
from tortoise.contrib.postgres.functions import Random as PostgresRandom
Expand All @@ -10,7 +10,7 @@ class TestFunction(test.TestCase):
async def asyncSetUp(self):
await super().asyncSetUp()
self.intfields = [await IntFields.create(intnum=val) for val in range(10)]
self.db = Tortoise.get_connection("models")
self.db = connections.get("models")

@test.requireCapability(dialect="mysql")
async def test_mysql_func_rand(self):
Expand Down
Loading