diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 4205d5d63..86372f629 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -13,10 +13,25 @@ Changelog
 Added
 ^^^^^
 - Added psycopg backend support.
+- Added a unified connection management interface for accessing DB connections,
+  including support for lazy connection creation. For more details,
+  check out this `PR `_.
 
 Fixed
 ^^^^^
 - Fix `bulk_create` doesn't work correctly with more than 1 update_fields. (#1046)
 - Fix `bulk_update` errors when setting null for a smallint column on postgres. (#1086)
+
+Deprecated
+^^^^^^^^^^
+- Deprecated the existing connection management interface and related public APIs:
+  - `Tortoise.get_connection`
+  - `Tortoise.close_connections`
+
+Changed
+^^^^^^^
+- Renamed `tortoise.transactions.get_connection` to `tortoise.transactions._get_connection`.
+  Note that this method is now **private to that module and not part of the public API**.
+
 0.18.1
 ------
diff --git a/docs/connections.rst b/docs/connections.rst
new file mode 100644
index 000000000..344c89fec
--- /dev/null
+++ b/docs/connections.rst
@@ -0,0 +1,58 @@
+.. _connections:
+
+===========
+Connections
+===========
+
+This document describes how to access the underlying connection object (:ref:`BaseDBAsyncClient <base_db_client>`)
+for the aliases defined in the DB config passed to the :meth:`Tortoise.init` call.
+
+Below is a simple code snippet showing how the interface can be accessed:
+
+.. code-block:: python3
+
+    # connections is a singleton instance of the ConnectionHandler class and serves as the
+    # entrypoint to access all connection management APIs.
+    from tortoise import BaseDBAsyncClient, Tortoise, connections
+    from tortoise.exceptions import OperationalError
+
+    # Assume that this is the Tortoise configuration used
+    await Tortoise.init(
+        {
+            "connections": {
+                "default": {
+                    "engine": "tortoise.backends.sqlite",
+                    "credentials": {"file_path": "example.sqlite3"},
+                }
+            },
+            "apps": {
+                "events": {"models": ["__main__"], "default_connection": "default"}
+            },
+        }
+    )
+
+    conn: BaseDBAsyncClient = connections.get("default")
+    try:
+        await conn.execute_query('SELECT * FROM "event"')
+    except OperationalError:
+        print("Expected it to fail")
+
+.. important::
+    The :ref:`ConnectionHandler <connection_handler>` class is implemented as a singleton: when the ORM
+    initializes, a single instance of this class, ``tortoise.connection.connections``, is created
+    automatically and lives in memory for the lifetime of the app.
+    Any attempt to modify or override its behaviour at runtime is risky and not recommended.
+
+
+Please refer to :ref:`this example` for a detailed demonstration of how this API can be used
+in practice.
+
+
+API Reference
+=============
+
+.. _connection_handler:
+
+.. automodule:: tortoise.connection
+    :members:
+    :undoc-members:
\ No newline at end of file
diff --git a/docs/databases.rst b/docs/databases.rst
index 2bff581df..4e3ac8725 100644
--- a/docs/databases.rst
+++ b/docs/databases.rst
@@ -214,6 +214,7 @@ handle complex objects.
     }
 )
 
+.. 
_base_db_client: Base DB client ============== diff --git a/docs/reference.rst b/docs/reference.rst index 67721f10b..fcfad2186 100644 --- a/docs/reference.rst +++ b/docs/reference.rst @@ -16,6 +16,7 @@ Reference functions expressions transactions + connections exceptions signals migration diff --git a/examples/manual_sql.py b/examples/manual_sql.py index 0205eed03..b3d46e036 100644 --- a/examples/manual_sql.py +++ b/examples/manual_sql.py @@ -1,7 +1,7 @@ """ This example demonstrates executing manual SQL queries """ -from tortoise import Tortoise, fields, run_async +from tortoise import Tortoise, connections, fields, run_async from tortoise.models import Model from tortoise.transactions import in_transaction @@ -17,7 +17,7 @@ async def run(): await Tortoise.generate_schemas() # Need to get a connection. Unless explicitly specified, the name should be 'default' - conn = Tortoise.get_connection("default") + conn = connections.get("default") # Now we can execute queries in the normal autocommit mode await conn.execute_query("INSERT INTO event (name) VALUES ('Foo')") diff --git a/examples/schema_create.py b/examples/schema_create.py index 96f7f91c7..4dcb47687 100644 --- a/examples/schema_create.py +++ b/examples/schema_create.py @@ -1,7 +1,7 @@ """ This example demonstrates SQL Schema generation for each DB type supported. """ -from tortoise import Tortoise, fields, run_async +from tortoise import Tortoise, connections, fields, run_async from tortoise.models import Model from tortoise.utils import get_schema_sql @@ -49,19 +49,19 @@ class Meta: async def run(): print("SQLite:\n") await Tortoise.init(db_url="sqlite://:memory:", modules={"models": ["__main__"]}) - sql = get_schema_sql(Tortoise.get_connection("default"), safe=False) + sql = get_schema_sql(connections.get("default"), safe=False) print(sql) print("\n\nMySQL:\n") await Tortoise.init(db_url="mysql://root:@127.0.0.1:3306/", modules={"models": ["__main__"]}) - sql = get_schema_sql(Tortoise.get_connection("default"), safe=False) + sql = get_schema_sql(connections.get("default"), safe=False) print(sql) print("\n\nPostgreSQL:\n") await Tortoise.init( db_url="postgres://postgres:@127.0.0.1:5432/", modules={"models": ["__main__"]} ) - sql = get_schema_sql(Tortoise.get_connection("default"), safe=False) + sql = get_schema_sql(connections.get("default"), safe=False) print(sql) diff --git a/examples/two_databases.py b/examples/two_databases.py index 88b8ed8fe..96d59d68b 100644 --- a/examples/two_databases.py +++ b/examples/two_databases.py @@ -8,7 +8,7 @@ Key notes of this example is using db_route for Tortoise init and explicitly declaring model apps in class Meta """ -from tortoise import Tortoise, fields, run_async +from tortoise import Tortoise, connections, fields, run_async from tortoise.exceptions import OperationalError from tortoise.models import Model @@ -73,8 +73,8 @@ async def run(): } ) await Tortoise.generate_schemas() - client = Tortoise.get_connection("first") - second_client = Tortoise.get_connection("second") + client = connections.get("first") + second_client = connections.get("second") tournament = await Tournament.create(name="Tournament") await Event(name="Event", tournament_id=tournament.id).save() diff --git a/tests/backends/test_capabilities.py b/tests/backends/test_capabilities.py index 5acb4a9c9..1a9c1be21 100644 --- a/tests/backends/test_capabilities.py +++ b/tests/backends/test_capabilities.py @@ -1,4 +1,4 @@ -from tortoise import Tortoise +from tortoise import connections from tortoise.contrib import test @@ -7,7 
+7,7 @@ class TestCapabilities(test.TestCase): async def asyncSetUp(self) -> None: await super(TestCapabilities, self).asyncSetUp() - self.db = Tortoise.get_connection("models") + self.db = connections.get("models") self.caps = self.db.capabilities def test_str(self): diff --git a/tests/backends/test_connection_params.py b/tests/backends/test_connection_params.py index ef6c4590c..ef346760a 100644 --- a/tests/backends/test_connection_params.py +++ b/tests/backends/test_connection_params.py @@ -2,20 +2,22 @@ import asyncpg -from tortoise import Tortoise +from tortoise import connections from tortoise.contrib import test -class TestConnectionParams(test.TestCase): +class TestConnectionParams(test.SimpleTestCase): async def asyncSetUp(self) -> None: - pass + await super().asyncSetUp() async def asyncTearDown(self) -> None: - pass + await super().asyncTearDown() async def test_mysql_connection_params(self): - with patch("asyncmy.create_pool", new=AsyncMock()) as mysql_connect: - await Tortoise._init_connections( + with patch( + "tortoise.backends.mysql.client.mysql.create_pool", new=AsyncMock() + ) as mysql_connect: + await connections._init( { "models": { "engine": "tortoise.backends.mysql", @@ -32,6 +34,7 @@ async def test_mysql_connection_params(self): }, False, ) + await connections.get("models").create_connection(with_db=True) mysql_connect.assert_awaited_once_with( # nosec autocommit=True, @@ -49,8 +52,10 @@ async def test_mysql_connection_params(self): async def test_asyncpg_connection_params(self): try: - with patch("asyncpg.create_pool", new=AsyncMock()) as asyncpg_connect: - await Tortoise._init_connections( + with patch( + "tortoise.backends.asyncpg.client.asyncpg.create_pool", new=AsyncMock() + ) as asyncpg_connect: + await connections._init( { "models": { "engine": "tortoise.backends.asyncpg", @@ -67,6 +72,7 @@ async def test_asyncpg_connection_params(self): }, False, ) + await connections.get("models").create_connection(with_db=True) asyncpg_connect.assert_awaited_once_with( # nosec None, @@ -88,8 +94,12 @@ async def test_asyncpg_connection_params(self): async def test_psycopg_connection_params(self): try: - with patch("psycopg_pool.AsyncConnectionPool.open", new=AsyncMock()) as psycopg_connect: - await Tortoise._init_connections( + with patch( + "tortoise.backends.psycopg.client.PsycopgClient.create_pool", new=AsyncMock() + ) as patched_create_pool: + mocked_pool = AsyncMock() + patched_create_pool.return_value = mocked_pool + await connections._init( { "models": { "engine": "tortoise.backends.psycopg", @@ -106,8 +116,9 @@ async def test_psycopg_connection_params(self): }, False, ) - - psycopg_connect.assert_awaited_once_with( # nosec + await connections.get("models").create_connection(with_db=True) + patched_create_pool.assert_awaited_once() + mocked_pool.open.assert_awaited_once_with( # nosec wait=True, timeout=1, ) diff --git a/tests/backends/test_mysql.py b/tests/backends/test_mysql.py index cb5babc02..4f5b427df 100644 --- a/tests/backends/test_mysql.py +++ b/tests/backends/test_mysql.py @@ -18,16 +18,17 @@ async def asyncSetUp(self): async def asyncTearDown(self) -> None: if Tortoise._inited: await Tortoise._drop_databases() + await super().asyncTearDown() async def test_bad_charset(self): self.db_config["connections"]["models"]["credentials"]["charset"] = "terrible" with self.assertRaisesRegex(ConnectionError, "Unknown charset"): - await Tortoise.init(self.db_config) + await Tortoise.init(self.db_config, _create_db=True) async def test_ssl_true(self): 
self.db_config["connections"]["models"]["credentials"]["ssl"] = True try: - await Tortoise.init(self.db_config) + await Tortoise.init(self.db_config, _create_db=True) except (ConnectionError, ssl.SSLError): pass else: diff --git a/tests/backends/test_postgres.py b/tests/backends/test_postgres.py index 3c5d19fc4..25b52d60f 100644 --- a/tests/backends/test_postgres.py +++ b/tests/backends/test_postgres.py @@ -4,7 +4,7 @@ import ssl from tests.testmodels import Tournament -from tortoise import Tortoise +from tortoise import Tortoise, connections from tortoise.contrib import test from tortoise.exceptions import OperationalError @@ -29,6 +29,7 @@ def is_asyncpg(self) -> bool: async def asyncTearDown(self) -> None: if Tortoise._inited: await Tortoise._drop_databases() + await super().asyncTearDown() async def test_schema(self): if self.is_asyncpg: @@ -42,12 +43,12 @@ async def test_schema(self): with self.assertRaises(InvalidSchemaNameError): await Tortoise.generate_schemas() - conn = Tortoise.get_connection("models") + conn = connections.get("models") await conn.execute_script("CREATE SCHEMA mytestschema;") await Tortoise.generate_schemas() tournament = await Tournament.create(name="Test") - await Tortoise.close_connections() + await connections.close_all() del self.db_config["connections"]["models"]["credentials"]["schema"] await Tortoise.init(self.db_config) @@ -55,7 +56,7 @@ async def test_schema(self): with self.assertRaises(OperationalError): await Tournament.filter(name="Test").first() - conn = Tortoise.get_connection("models") + conn = connections.get("models") _, res = await conn.execute_query( "SELECT id, name FROM mytestschema.tournament WHERE name='Test' LIMIT 1" ) @@ -67,7 +68,7 @@ async def test_schema(self): async def test_ssl_true(self): self.db_config["connections"]["models"]["credentials"]["ssl"] = True try: - await Tortoise.init(self.db_config) + await Tortoise.init(self.db_config, _create_db=True) except (ConnectionError, ssl.SSLError): pass else: @@ -91,7 +92,7 @@ async def test_application_name(self): ] = "mytest_application" await Tortoise.init(self.db_config, _create_db=True) - conn = Tortoise.get_connection("models") + conn = connections.get("models") _, res = await conn.execute_query( "SELECT application_name FROM pg_stat_activity WHERE pid = pg_backend_pid()" ) diff --git a/tests/backends/test_reconnect.py b/tests/backends/test_reconnect.py index 75b2eb04a..b0e07ef6b 100644 --- a/tests/backends/test_reconnect.py +++ b/tests/backends/test_reconnect.py @@ -1,5 +1,5 @@ from tests.testmodels import Tournament -from tortoise import Tortoise +from tortoise import connections from tortoise.contrib import test from tortoise.transactions import in_transaction @@ -9,11 +9,11 @@ class TestReconnect(test.IsolatedTestCase): async def test_reconnect(self): await Tournament.create(name="1") - await Tortoise._connections["models"]._expire_connections() + await connections.get("models")._expire_connections() await Tournament.create(name="2") - await Tortoise._connections["models"]._expire_connections() + await connections.get("models")._expire_connections() await Tournament.create(name="3") @@ -26,12 +26,12 @@ async def test_reconnect_transaction_start(self): async with in_transaction(): await Tournament.create(name="1") - await Tortoise._connections["models"]._expire_connections() + await connections.get("models")._expire_connections() async with in_transaction(): await Tournament.create(name="2") - await Tortoise._connections["models"]._expire_connections() + await 
connections.get("models")._expire_connections() async with in_transaction(): self.assertEqual([f"{a.id}:{a.name}" for a in await Tournament.all()], ["1:1", "2:2"]) diff --git a/tests/contrib/test_functions.py b/tests/contrib/test_functions.py index 5abf398f1..6f4e2c1ff 100644 --- a/tests/contrib/test_functions.py +++ b/tests/contrib/test_functions.py @@ -1,5 +1,5 @@ from tests.testmodels import IntFields -from tortoise import Tortoise +from tortoise import connections from tortoise.contrib import test from tortoise.contrib.mysql.functions import Rand from tortoise.contrib.postgres.functions import Random as PostgresRandom @@ -10,7 +10,7 @@ class TestFunction(test.TestCase): async def asyncSetUp(self): await super().asyncSetUp() self.intfields = [await IntFields.create(intnum=val) for val in range(10)] - self.db = Tortoise.get_connection("models") + self.db = connections.get("models") @test.requireCapability(dialect="mysql") async def test_mysql_func_rand(self): diff --git a/tests/model_setup/test__models__.py b/tests/model_setup/test__models__.py index 77c64ffa0..b3ed1727b 100644 --- a/tests/model_setup/test__models__.py +++ b/tests/model_setup/test__models__.py @@ -5,7 +5,7 @@ import re from unittest.mock import AsyncMock, patch -from tortoise import Tortoise +from tortoise import Tortoise, connections from tortoise.contrib import test from tortoise.exceptions import ConfigurationError from tortoise.utils import get_schema_sql @@ -16,7 +16,6 @@ async def asyncSetUp(self): await super().asyncSetUp() try: Tortoise.apps = {} - Tortoise._connections = {} Tortoise._inited = False except ConfigurationError: pass @@ -28,7 +27,6 @@ async def asyncSetUp(self): ] async def asyncTearDown(self) -> None: - Tortoise._connections = {} await Tortoise._reset_apps() async def init_for(self, module: str, safe=False) -> None: @@ -48,7 +46,7 @@ async def init_for(self, module: str, safe=False) -> None: "apps": {"models": {"models": [module], "default_connection": "default"}}, } ) - self.sqls = get_schema_sql(Tortoise._connections["default"], safe).split(";\n") + self.sqls = get_schema_sql(connections.get("default"), safe).split(";\n") def get_sql(self, text: str) -> str: return str(re.sub(r"[ \t\n\r]+", " ", [sql for sql in self.sqls if text in sql][0])) diff --git a/tests/model_setup/test_bad_relation_reference.py b/tests/model_setup/test_bad_relation_reference.py index 5d7c25d9a..7fcf721e3 100644 --- a/tests/model_setup/test_bad_relation_reference.py +++ b/tests/model_setup/test_bad_relation_reference.py @@ -5,16 +5,15 @@ class TestBadReleationReferenceErrors(test.SimpleTestCase): async def asyncSetUp(self): + await super().asyncSetUp() try: Tortoise.apps = {} - Tortoise._connections = {} Tortoise._inited = False except ConfigurationError: pass Tortoise._inited = False async def asyncTearDown(self) -> None: - await Tortoise.close_connections() await Tortoise._reset_apps() await super(TestBadReleationReferenceErrors, self).asyncTearDown() diff --git a/tests/model_setup/test_init.py b/tests/model_setup/test_init.py index 0fe9f838c..26d4a8682 100644 --- a/tests/model_setup/test_init.py +++ b/tests/model_setup/test_init.py @@ -1,6 +1,6 @@ import os -from tortoise import Tortoise +from tortoise import Tortoise, connections from tortoise.contrib import test from tortoise.exceptions import ConfigurationError @@ -10,14 +10,12 @@ async def asyncSetUp(self): await super().asyncSetUp() try: Tortoise.apps = {} - Tortoise._connections = {} Tortoise._inited = False except ConfigurationError: pass Tortoise._inited = False 
async def asyncTearDown(self) -> None: - await Tortoise.close_connections() await Tortoise._reset_apps() await super(TestInitErrors, self).asyncTearDown() @@ -36,7 +34,7 @@ async def test_basic_init(self): } ) self.assertIn("models", Tortoise.apps) - self.assertIsNotNone(Tortoise.get_connection("default")) + self.assertIsNotNone(connections.get("default")) async def test_empty_modules_init(self): with self.assertWarnsRegex(RuntimeWarning, 'Module "tests.model_setup" has no models'): @@ -184,7 +182,11 @@ async def test_nonpk_id(self): ) async def test_unknown_connection(self): - with self.assertRaisesRegex(ConfigurationError, 'Unknown connection "fioop"'): + with self.assertRaisesRegex( + ConfigurationError, + "Unable to get db settings for alias 'fioop'. Please " + "check if the config dict contains this alias and try again", + ): await Tortoise.init( { "connections": { @@ -218,7 +220,7 @@ async def test_default_connection_init(self): } ) self.assertIn("models", Tortoise.apps) - self.assertIsNotNone(Tortoise.get_connection("default")) + self.assertIsNotNone(connections.get("default")) async def test_db_url_init(self): await Tortoise.init( @@ -230,14 +232,14 @@ async def test_db_url_init(self): } ) self.assertIn("models", Tortoise.apps) - self.assertIsNotNone(Tortoise.get_connection("default")) + self.assertIsNotNone(connections.get("default")) async def test_shorthand_init(self): await Tortoise.init( db_url=f"sqlite://{':memory:'}", modules={"models": ["tests.testmodels"]} ) self.assertIn("models", Tortoise.apps) - self.assertIsNotNone(Tortoise.get_connection("default")) + self.assertIsNotNone(connections.get("default")) async def test_init_wrong_connection_engine(self): with self.assertRaisesRegex(ImportError, "tortoise.backends.test"): @@ -326,13 +328,13 @@ async def test_init_config_file_wrong_extension(self): async def test_init_json_file(self): await Tortoise.init(config_file=os.path.dirname(__file__) + "/init.json") self.assertIn("models", Tortoise.apps) - self.assertIsNotNone(Tortoise.get_connection("default")) + self.assertIsNotNone(connections.get("default")) @test.skipIf(os.name == "nt", "path issue on Windows") async def test_init_yaml_file(self): await Tortoise.init(config_file=os.path.dirname(__file__) + "/init.yaml") self.assertIn("models", Tortoise.apps) - self.assertIsNotNone(Tortoise.get_connection("default")) + self.assertIsNotNone(connections.get("default")) async def test_generate_schema_without_init(self): with self.assertRaisesRegex( diff --git a/tests/schema/test_generate_schema.py b/tests/schema/test_generate_schema.py index 3b0eaa616..34f6c65a3 100644 --- a/tests/schema/test_generate_schema.py +++ b/tests/schema/test_generate_schema.py @@ -2,7 +2,7 @@ import re from unittest.mock import AsyncMock, patch -from tortoise import Tortoise +from tortoise import Tortoise, connections from tortoise.contrib import test from tortoise.exceptions import ConfigurationError from tortoise.utils import get_schema_sql @@ -13,7 +13,6 @@ async def asyncSetUp(self): await super().asyncSetUp() try: Tortoise.apps = {} - Tortoise._connections = {} Tortoise._inited = False except ConfigurationError: pass @@ -25,11 +24,8 @@ async def asyncSetUp(self): ] async def asyncTearDown(self) -> None: - for connection in Tortoise._connections.values(): - await connection.close() - - Tortoise._connections = {} await Tortoise._reset_apps() + await super().asyncTearDown() async def init_for(self, module: str, safe=False) -> None: with patch( @@ -46,7 +42,7 @@ async def init_for(self, module: str, 
safe=False) -> None: "apps": {"models": {"models": [module], "default_connection": "default"}}, } ) - self.sqls = get_schema_sql(Tortoise._connections["default"], safe).split(";\n") + self.sqls = get_schema_sql(connections.get("default"), safe).split(";\n") def get_sql(self, text: str) -> str: return re.sub(r"[ \t\n\r]+", " ", " ".join([sql for sql in self.sqls if text in sql])) @@ -141,7 +137,7 @@ async def test_table_and_row_comment_generation(self): async def test_schema_no_db_constraint(self): self.maxDiff = None await self.init_for("tests.schema.models_no_db_constraint") - sql = get_schema_sql(Tortoise.get_connection("default"), safe=False) + sql = get_schema_sql(connections.get("default"), safe=False) self.assertEqual( sql.strip(), r"""CREATE TABLE "team" ( @@ -181,7 +177,7 @@ async def test_schema_no_db_constraint(self): async def test_schema(self): self.maxDiff = None await self.init_for("tests.schema.models_schema_create") - sql = get_schema_sql(Tortoise.get_connection("default"), safe=False) + sql = get_schema_sql(connections.get("default"), safe=False) self.assertEqual( sql.strip(), """ @@ -268,7 +264,7 @@ async def test_schema(self): async def test_schema_safe(self): self.maxDiff = None await self.init_for("tests.schema.models_schema_create") - sql = get_schema_sql(Tortoise.get_connection("default"), safe=True) + sql = get_schema_sql(connections.get("default"), safe=True) self.assertEqual( sql.strip(), """ @@ -355,7 +351,7 @@ async def test_schema_safe(self): async def test_m2m_no_auto_create(self): self.maxDiff = None await self.init_for("tests.schema.models_no_auto_create_m2m") - sql = get_schema_sql(Tortoise.get_connection("default"), safe=False) + sql = get_schema_sql(connections.get("default"), safe=False) self.assertEqual( sql.strip(), r"""CREATE TABLE "team" ( @@ -419,7 +415,7 @@ async def init_for(self, module: str, safe=False) -> None: "apps": {"models": {"models": [module], "default_connection": "default"}}, } ) - self.sqls = get_schema_sql(Tortoise._connections["default"], safe).split("; ") + self.sqls = get_schema_sql(connections.get("default"), safe).split("; ") except ImportError: raise test.SkipTest("aiomysql not installed") @@ -462,7 +458,7 @@ async def test_table_and_row_comment_generation(self): async def test_schema_no_db_constraint(self): self.maxDiff = None await self.init_for("tests.schema.models_no_db_constraint") - sql = get_schema_sql(Tortoise.get_connection("default"), safe=False) + sql = get_schema_sql(connections.get("default"), safe=False) self.assertEqual( sql.strip(), r"""CREATE TABLE `team` ( @@ -502,7 +498,7 @@ async def test_schema_no_db_constraint(self): async def test_schema(self): self.maxDiff = None await self.init_for("tests.schema.models_schema_create") - sql = get_schema_sql(Tortoise.get_connection("default"), safe=False) + sql = get_schema_sql(connections.get("default"), safe=False) self.assertEqual( sql.strip(), """ @@ -601,7 +597,7 @@ async def test_schema(self): async def test_schema_safe(self): self.maxDiff = None await self.init_for("tests.schema.models_schema_create") - sql = get_schema_sql(Tortoise.get_connection("default"), safe=True) + sql = get_schema_sql(connections.get("default"), safe=True) self.assertEqual( sql.strip(), @@ -700,7 +696,7 @@ async def test_schema_safe(self): async def test_index_safe(self): await self.init_for("tests.schema.models_mysql_index") - sql = get_schema_sql(Tortoise.get_connection("default"), safe=True) + sql = get_schema_sql(connections.get("default"), safe=True) self.assertEqual( sql, """CREATE 
TABLE IF NOT EXISTS `index` ( @@ -714,7 +710,7 @@ async def test_index_safe(self): async def test_index_unsafe(self): await self.init_for("tests.schema.models_mysql_index") - sql = get_schema_sql(Tortoise.get_connection("default"), safe=False) + sql = get_schema_sql(connections.get("default"), safe=False) self.assertEqual( sql, """CREATE TABLE `index` ( @@ -729,7 +725,7 @@ async def test_index_unsafe(self): async def test_m2m_no_auto_create(self): self.maxDiff = None await self.init_for("tests.schema.models_no_auto_create_m2m") - sql = get_schema_sql(Tortoise.get_connection("default"), safe=False) + sql = get_schema_sql(connections.get("default"), safe=False) self.assertEqual( sql.strip(), r"""CREATE TABLE `team` ( @@ -802,7 +798,7 @@ async def test_table_and_row_comment_generation(self): async def test_schema_no_db_constraint(self): self.maxDiff = None await self.init_for("tests.schema.models_no_db_constraint") - sql = get_schema_sql(Tortoise.get_connection("default"), safe=False) + sql = get_schema_sql(connections.get("default"), safe=False) self.assertEqual( sql.strip(), r"""CREATE TABLE "team" ( @@ -852,7 +848,7 @@ async def test_schema_no_db_constraint(self): async def test_schema(self): self.maxDiff = None await self.init_for("tests.schema.models_schema_create") - sql = get_schema_sql(Tortoise.get_connection("default"), safe=False) + sql = get_schema_sql(connections.get("default"), safe=False) self.assertEqual( sql.strip(), """ @@ -954,7 +950,7 @@ async def test_schema(self): async def test_schema_safe(self): self.maxDiff = None await self.init_for("tests.schema.models_schema_create") - sql = get_schema_sql(Tortoise.get_connection("default"), safe=True) + sql = get_schema_sql(connections.get("default"), safe=True) self.assertEqual( sql.strip(), """ @@ -1055,7 +1051,7 @@ async def test_schema_safe(self): async def test_index_unsafe(self): await self.init_for("tests.schema.models_postgres_index") - sql = get_schema_sql(Tortoise.get_connection("default"), safe=False) + sql = get_schema_sql(connections.get("default"), safe=False) self.assertEqual( sql, """CREATE TABLE "index" ( @@ -1077,7 +1073,7 @@ async def test_index_unsafe(self): async def test_index_safe(self): await self.init_for("tests.schema.models_postgres_index") - sql = get_schema_sql(Tortoise.get_connection("default"), safe=True) + sql = get_schema_sql(connections.get("default"), safe=True) self.assertEqual( sql, """CREATE TABLE IF NOT EXISTS "index" ( @@ -1100,7 +1096,7 @@ async def test_index_safe(self): async def test_m2m_no_auto_create(self): self.maxDiff = None await self.init_for("tests.schema.models_no_auto_create_m2m") - sql = get_schema_sql(Tortoise.get_connection("default"), safe=False) + sql = get_schema_sql(connections.get("default"), safe=False) self.assertEqual( sql.strip(), r"""CREATE TABLE "team" ( @@ -1172,7 +1168,7 @@ async def init_for(self, module: str, safe=False) -> None: "apps": {"models": {"models": [module], "default_connection": "default"}}, } ) - self.sqls = get_schema_sql(Tortoise._connections["default"], safe).split("; ") + self.sqls = get_schema_sql(connections.get("default"), safe).split("; ") except ImportError: raise test.SkipTest("asyncpg not installed") @@ -1198,6 +1194,6 @@ async def init_for(self, module: str, safe=False) -> None: "apps": {"models": {"models": [module], "default_connection": "default"}}, } ) - self.sqls = get_schema_sql(Tortoise._connections["default"], safe).split("; ") + self.sqls = get_schema_sql(connections.get("default"), safe).split("; ") except ImportError: raise 
test.SkipTest("psycopg not installed") diff --git a/tests/test_case_when.py b/tests/test_case_when.py index 2ec15b97b..fa041d7e4 100644 --- a/tests/test_case_when.py +++ b/tests/test_case_when.py @@ -1,5 +1,5 @@ from tests.testmodels import IntFields -from tortoise import Tortoise +from tortoise import connections from tortoise.contrib import test from tortoise.expressions import Case, F, Q, When from tortoise.functions import Coalesce @@ -9,7 +9,7 @@ class TestCaseWhen(test.TestCase): async def asyncSetUp(self): await super().asyncSetUp() self.intfields = [await IntFields.create(intnum=val) for val in range(10)] - self.db = Tortoise.get_connection("models") + self.db = connections.get("models") async def test_single_when(self): category = Case(When(intnum__gte=8, then="big"), default="default") diff --git a/tests/test_connection.py b/tests/test_connection.py new file mode 100644 index 000000000..7f6841330 --- /dev/null +++ b/tests/test_connection.py @@ -0,0 +1,280 @@ +from contextvars import ContextVar +from unittest.mock import AsyncMock, Mock, PropertyMock, call, patch + +from tortoise import BaseDBAsyncClient, ConfigurationError +from tortoise.connection import ConnectionHandler +from tortoise.contrib.test import SimpleTestCase + + +class TestConnections(SimpleTestCase): + def setUp(self) -> None: + self.conn_handler = ConnectionHandler() + + def test_init_constructor(self): + self.assertIsNone(self.conn_handler._db_config) + self.assertFalse(self.conn_handler._create_db) + + @patch("tortoise.connection.ConnectionHandler._init_connections") + async def test_init(self, mocked_init_connections: AsyncMock): + db_config = {"default": {"HOST": "some_host", "PORT": "1234"}} + await self.conn_handler._init(db_config, True) + mocked_init_connections.assert_awaited_once() + self.assertEqual(db_config, self.conn_handler._db_config) + self.assertTrue(self.conn_handler._create_db) + + def test_db_config_present(self): + self.conn_handler._db_config = {"default": {"HOST": "some_host", "PORT": "1234"}} + self.assertEqual(self.conn_handler.db_config, self.conn_handler._db_config) + + def test_db_config_not_present(self): + err_msg = ( + "DB configuration not initialised. Make sure to call " + "Tortoise.init with a valid configuration before attempting " + "to create connections." 
+ ) + with self.assertRaises(ConfigurationError, msg=err_msg): + _ = self.conn_handler.db_config + + @patch("tortoise.connection.ConnectionHandler._conn_storage", spec=ContextVar) + def test_get_storage(self, mocked_conn_storage: Mock): + expected_ret_val = {"default": BaseDBAsyncClient("default")} + mocked_conn_storage.get.return_value = expected_ret_val + ret_val = self.conn_handler._get_storage() + self.assertDictEqual(ret_val, expected_ret_val) + + @patch("tortoise.connection.ConnectionHandler._conn_storage", spec=ContextVar) + def test_set_storage(self, mocked_conn_storage: Mock): + mocked_conn_storage.set.return_value = "blah" + new_storage = {"default": BaseDBAsyncClient("default")} + ret_val = self.conn_handler._set_storage(new_storage) + mocked_conn_storage.set.assert_called_once_with(new_storage) + self.assertEqual(ret_val, mocked_conn_storage.set.return_value) + + @patch("tortoise.connection.ConnectionHandler._get_storage") + @patch("tortoise.connection.copy") + def test_copy_storage(self, mocked_copy: Mock, mocked_get_storage: Mock): + expected_ret_value = {"default": BaseDBAsyncClient("default")} + mocked_get_storage.return_value = expected_ret_value + mocked_copy.return_value = expected_ret_value.copy() + ret_val = self.conn_handler._copy_storage() + mocked_get_storage.assert_called_once() + mocked_copy.assert_called_once_with(mocked_get_storage.return_value) + self.assertDictEqual(ret_val, expected_ret_value) + self.assertNotEqual(id(expected_ret_value), id(ret_val)) + + @patch("tortoise.connection.ConnectionHandler._get_storage") + def test_clear_storage(self, mocked_get_storage: Mock): + self.conn_handler._clear_storage() + mocked_get_storage.assert_called_once() + mocked_get_storage.return_value.clear.assert_called_once() + + @patch("tortoise.connection.importlib.import_module") + def test_discover_client_class_proper_impl(self, mocked_import_module: Mock): + mocked_import_module.return_value = Mock(client_class="some_class") + client_class = self.conn_handler._discover_client_class("blah") + mocked_import_module.assert_called_once_with("blah") + self.assertEqual(client_class, "some_class") + + @patch("tortoise.connection.importlib.import_module") + def test_discover_client_class_improper_impl(self, mocked_import_module: Mock): + del mocked_import_module.return_value.client_class + engine = "some_engine" + with self.assertRaises( + ConfigurationError, msg=f'Backend for engine "{engine}" does not implement db client' + ): + _ = self.conn_handler._discover_client_class(engine) + + @patch("tortoise.connection.ConnectionHandler.db_config", new_callable=PropertyMock) + def test_get_db_info_present(self, mocked_db_config: Mock): + expected_ret_val = {"HOST": "some_host", "PORT": "1234"} + mocked_db_config.return_value = {"default": expected_ret_val} + ret_val = self.conn_handler._get_db_info("default") + self.assertEqual(ret_val, expected_ret_val) + + @patch("tortoise.connection.ConnectionHandler.db_config", new_callable=PropertyMock) + def test_get_db_info_not_present(self, mocked_db_config: Mock): + mocked_db_config.return_value = {"default": {"HOST": "some_host", "PORT": "1234"}} + conn_alias = "blah" + with self.assertRaises( + ConfigurationError, + msg=f"Unable to get db settings for alias '{conn_alias}'. 
Please " + f"check if the config dict contains this alias and try again", + ): + _ = self.conn_handler._get_db_info(conn_alias) + + @patch("tortoise.connection.ConnectionHandler.db_config", new_callable=PropertyMock) + @patch("tortoise.connection.ConnectionHandler.get") + async def test_init_connections_no_db_create(self, mocked_get: Mock, mocked_db_config: Mock): + conn_1, conn_2 = AsyncMock(spec=BaseDBAsyncClient), AsyncMock(spec=BaseDBAsyncClient) + mocked_get.side_effect = [conn_1, conn_2] + mocked_db_config.return_value = { + "default": {"HOST": "some_host", "PORT": "1234"}, + "other": {"HOST": "some_other_host", "PORT": "1234"}, + } + await self.conn_handler._init_connections() + mocked_db_config.assert_called_once() + mocked_get.assert_has_calls([call("default"), call("other")], any_order=True) + conn_1.db_create.assert_not_awaited() + conn_2.db_create.assert_not_awaited() + + @patch("tortoise.connection.ConnectionHandler.db_config", new_callable=PropertyMock) + @patch("tortoise.connection.ConnectionHandler.get") + async def test_init_connections_db_create(self, mocked_get: Mock, mocked_db_config: Mock): + self.conn_handler._create_db = True + conn_1, conn_2 = AsyncMock(spec=BaseDBAsyncClient), AsyncMock(spec=BaseDBAsyncClient) + mocked_get.side_effect = [conn_1, conn_2] + mocked_db_config.return_value = { + "default": {"HOST": "some_host", "PORT": "1234"}, + "other": {"HOST": "some_other_host", "PORT": "1234"}, + } + await self.conn_handler._init_connections() + mocked_db_config.assert_called_once() + mocked_get.assert_has_calls([call("default"), call("other")], any_order=True) + conn_1.db_create.assert_awaited_once() + conn_2.db_create.assert_awaited_once() + + @patch("tortoise.connection.ConnectionHandler._get_db_info") + @patch("tortoise.connection.expand_db_url") + @patch("tortoise.connection.ConnectionHandler._discover_client_class") + def test_create_connection_db_info_str( + self, + mocked_discover_client_class: Mock, + mocked_expand_db_url: Mock, + mocked_get_db_info: Mock, + ): + alias = "default" + mocked_get_db_info.return_value = "some_db_url" + mocked_expand_db_url.return_value = { + "engine": "some_engine", + "credentials": {"cred_key": "some_val"}, + } + expected_client_class = Mock(return_value="some_connection") + mocked_discover_client_class.return_value = expected_client_class + expected_db_params = {"cred_key": "some_val", "connection_name": alias} + + ret_val = self.conn_handler._create_connection(alias) + + mocked_get_db_info.assert_called_once_with(alias) + mocked_expand_db_url.assert_called_once_with("some_db_url") + mocked_discover_client_class.assert_called_once_with("some_engine") + expected_client_class.assert_called_once_with(**expected_db_params) + self.assertEqual(ret_val, "some_connection") + + @patch("tortoise.connection.ConnectionHandler._get_db_info") + @patch("tortoise.connection.expand_db_url") + @patch("tortoise.connection.ConnectionHandler._discover_client_class") + def test_create_connection_db_info_not_str( + self, + mocked_discover_client_class: Mock, + mocked_expand_db_url: Mock, + mocked_get_db_info: Mock, + ): + alias = "default" + mocked_get_db_info.return_value = { + "engine": "some_engine", + "credentials": {"cred_key": "some_val"}, + } + expected_client_class = Mock(return_value="some_connection") + mocked_discover_client_class.return_value = expected_client_class + expected_db_params = {"cred_key": "some_val", "connection_name": alias} + + ret_val = self.conn_handler._create_connection(alias) + + 
mocked_get_db_info.assert_called_once_with(alias) + mocked_expand_db_url.assert_not_called() + mocked_discover_client_class.assert_called_once_with("some_engine") + expected_client_class.assert_called_once_with(**expected_db_params) + self.assertEqual(ret_val, "some_connection") + + @patch("tortoise.connection.ConnectionHandler._get_storage") + @patch("tortoise.connection.ConnectionHandler._create_connection") + def test_get_alias_present(self, mocked_create_connection: Mock, mocked_get_storage: Mock): + mocked_get_storage.return_value = {"default": "some_connection"} + ret_val = self.conn_handler.get("default") + mocked_get_storage.assert_called_once() + mocked_create_connection.assert_not_called() + self.assertEqual(ret_val, "some_connection") + + @patch("tortoise.connection.ConnectionHandler._get_storage") + @patch("tortoise.connection.ConnectionHandler._create_connection") + def test_get_alias_not_present(self, mocked_create_connection: Mock, mocked_get_storage: Mock): + mocked_get_storage.return_value = {"default": "some_connection"} + expected_final_dict = {**mocked_get_storage.return_value, "other": "some_other_connection"} + mocked_create_connection.return_value = "some_other_connection" + ret_val = self.conn_handler.get("other") + mocked_get_storage.assert_called_once() + mocked_create_connection.assert_called_once_with("other") + self.assertEqual(ret_val, "some_other_connection") + self.assertDictEqual(mocked_get_storage.return_value, expected_final_dict) + + @patch("tortoise.connection.ConnectionHandler._conn_storage", spec=ContextVar) + @patch("tortoise.connection.ConnectionHandler._copy_storage") + def test_set(self, mocked_copy_storage: Mock, mocked_conn_storage: Mock): + mocked_copy_storage.return_value = {} + expected_storage = {"default": "some_conn"} + mocked_conn_storage.set.return_value = "some_token" + ret_val = self.conn_handler.set("default", "some_conn") # type: ignore + mocked_copy_storage.assert_called_once() + self.assertEqual(ret_val, mocked_conn_storage.set.return_value) + self.assertDictEqual(expected_storage, mocked_copy_storage.return_value) + + @patch("tortoise.connection.ConnectionHandler._get_storage") + def test_discard(self, mocked_get_storage: Mock): + mocked_get_storage.return_value = {"default": "some_conn"} + ret_val = self.conn_handler.discard("default") + self.assertEqual(ret_val, "some_conn") + self.assertDictEqual({}, mocked_get_storage.return_value) + + @patch("tortoise.connection.ConnectionHandler._conn_storage", spec=ContextVar) + @patch("tortoise.connection.ConnectionHandler._get_storage") + def test_reset(self, mocked_get_storage: Mock, mocked_conn_storage: Mock): + first_config = {"other": "some_other_conn", "default": "diff_conn"} + second_config = {"default": "some_conn"} + mocked_get_storage.side_effect = [first_config, second_config] + final_storage = {"default": "some_conn", "other": "some_other_conn"} + self.conn_handler.reset("some_token") # type: ignore + mocked_get_storage.assert_has_calls([call(), call()]) + mocked_conn_storage.reset.assert_called_once_with("some_token") + self.assertDictEqual(final_storage, second_config) + + @patch("tortoise.connection.ConnectionHandler.db_config", new_callable=PropertyMock) + @patch("tortoise.connection.ConnectionHandler.get") + def test_all(self, mocked_get: Mock, mocked_db_config: Mock): + db_config = {"default": "some_conn", "other": "some_other_conn"} + + def side_effect_callable(alias): + return db_config[alias] + + mocked_get.side_effect = side_effect_callable + 
mocked_db_config.return_value = db_config + expected_result = ["some_conn", "some_other_conn"] + ret_val = self.conn_handler.all() + mocked_db_config.assert_called_once() + mocked_get.assert_has_calls([call("default"), call("other")], any_order=True) + self.assertEqual(ret_val, expected_result) + + @patch("tortoise.connection.ConnectionHandler.all") + @patch("tortoise.connection.ConnectionHandler.discard") + @patch("tortoise.connection.ConnectionHandler.db_config", new_callable=PropertyMock) + async def test_close_all_with_discard( + self, mocked_db_config: Mock, mocked_discard: Mock, mocked_all: Mock + ): + all_conn = [AsyncMock(spec=BaseDBAsyncClient), AsyncMock(spec=BaseDBAsyncClient)] + db_config = {"default": "some_config", "other": "some_other_config"} + mocked_all.return_value = all_conn + mocked_db_config.return_value = db_config + await self.conn_handler.close_all() + mocked_all.assert_called_once() + mocked_db_config.assert_called_once() + for mock_obj in all_conn: + mock_obj.close.assert_awaited_once() + mocked_discard.assert_has_calls([call("default"), call("other")], any_order=True) + + @patch("tortoise.connection.ConnectionHandler.all") + async def test_close_all_without_discard(self, mocked_all: Mock): + all_conn = [AsyncMock(spec=BaseDBAsyncClient), AsyncMock(spec=BaseDBAsyncClient)] + mocked_all.return_value = all_conn + await self.conn_handler.close_all(discard=False) + mocked_all.assert_called_once() + for mock_obj in all_conn: + mock_obj.close.assert_awaited_once() diff --git a/tests/test_manual_sql.py b/tests/test_manual_sql.py index eeac16078..bfbf5cc95 100644 --- a/tests/test_manual_sql.py +++ b/tests/test_manual_sql.py @@ -1,11 +1,11 @@ -from tortoise import Tortoise +from tortoise import connections from tortoise.contrib import test from tortoise.transactions import in_transaction class TestManualSQL(test.TruncationTestCase): async def test_simple_insert(self): - conn = Tortoise.get_connection("models") + conn = connections.get("models") await conn.execute_query("INSERT INTO author (name) VALUES ('Foo')") self.assertEqual( await conn.execute_query_dict("SELECT name FROM author"), [{"name": "Foo"}] @@ -15,7 +15,7 @@ async def test_in_transaction(self): async with in_transaction() as conn: await conn.execute_query("INSERT INTO author (name) VALUES ('Foo')") - conn = Tortoise.get_connection("models") + conn = connections.get("models") self.assertEqual( await conn.execute_query_dict("SELECT name FROM author"), [{"name": "Foo"}] ) @@ -29,7 +29,7 @@ async def test_in_transaction_exception(self): except ValueError: pass - conn = Tortoise.get_connection("models") + conn = connections.get("models") self.assertEqual(await conn.execute_query_dict("SELECT name FROM author"), []) @test.requireCapability(supports_transactions=True) @@ -38,7 +38,7 @@ async def test_in_transaction_rollback(self): await conn.execute_query("INSERT INTO author (name) VALUES ('Foo')") await conn.rollback() - conn = Tortoise.get_connection("models") + conn = connections.get("models") self.assertEqual(await conn.execute_query_dict("SELECT name FROM author"), []) async def test_in_transaction_commit(self): @@ -50,7 +50,7 @@ async def test_in_transaction_commit(self): except ValueError: pass - conn = Tortoise.get_connection("models") + conn = connections.get("models") self.assertEqual( await conn.execute_query_dict("SELECT name FROM author"), [{"name": "Foo"}] ) diff --git a/tests/test_queryset.py b/tests/test_queryset.py index 66a58d0ae..5d56e7584 100644 --- a/tests/test_queryset.py +++ 
b/tests/test_queryset.py @@ -1,5 +1,5 @@ from tests.testmodels import Event, IntFields, MinRelation, Node, Reporter, Tournament, Tree -from tortoise import Tortoise +from tortoise import connections from tortoise.contrib import test from tortoise.exceptions import ( DoesNotExist, @@ -19,7 +19,7 @@ async def asyncSetUp(self): await super().asyncSetUp() # Build large dataset self.intfields = [await IntFields.create(intnum=val) for val in range(10, 100, 3)] - self.db = Tortoise.get_connection("models") + self.db = connections.get("models") async def test_all_count(self): self.assertEqual(await IntFields.all().count(), 30) diff --git a/tests/test_two_databases.py b/tests/test_two_databases.py index 41b7a79ad..cbf95a601 100644 --- a/tests/test_two_databases.py +++ b/tests/test_two_databases.py @@ -1,5 +1,5 @@ from tests.testmodels import Event, EventTwo, TeamTwo, Tournament -from tortoise import Tortoise +from tortoise import Tortoise, connections from tortoise.contrib import test from tortoise.exceptions import OperationalError, ParamsError from tortoise.transactions import in_transaction @@ -18,8 +18,8 @@ async def asyncSetUp(self): } await Tortoise.init(merged_config, _create_db=True) await Tortoise.generate_schemas() - self.db = Tortoise.get_connection("models") - self.second_db = Tortoise.get_connection("events") + self.db = connections.get("models") + self.second_db = connections.get("events") async def asyncTearDown(self) -> None: await Tortoise._drop_databases() diff --git a/tests/utils/test_run_async.py b/tests/utils/test_run_async.py index c181e057a..1508cad52 100644 --- a/tests/utils/test_run_async.py +++ b/tests/utils/test_run_async.py @@ -1,7 +1,7 @@ import os from unittest import skipIf -from tortoise import Tortoise, run_async +from tortoise import Tortoise, connections, run_async from tortoise.contrib.test import TestCase @@ -19,25 +19,25 @@ def setUp(self): async def init(self): await Tortoise.init(db_url="sqlite://:memory:", modules={"models": []}) self.somevalue = 2 - self.assertNotEqual(Tortoise._connections, {}) + self.assertNotEqual(connections._get_storage(), {}) async def init_raise(self): await Tortoise.init(db_url="sqlite://:memory:", modules={"models": []}) self.somevalue = 3 - self.assertNotEqual(Tortoise._connections, {}) + self.assertNotEqual(connections._get_storage(), {}) raise Exception("Some exception") def test_run_async(self): - self.assertEqual(Tortoise._connections, {}) + self.assertEqual(connections._get_storage(), {}) self.assertEqual(self.somevalue, 1) run_async(self.init()) - self.assertEqual(Tortoise._connections, {}) + self.assertEqual(connections._get_storage(), {}) self.assertEqual(self.somevalue, 2) def test_run_async_raised(self): - self.assertEqual(Tortoise._connections, {}) + self.assertEqual(connections._get_storage(), {}) self.assertEqual(self.somevalue, 1) with self.assertRaises(Exception): run_async(self.init_raise()) - self.assertEqual(Tortoise._connections, {}) + self.assertEqual(connections._get_storage(), {}) self.assertEqual(self.somevalue, 3) diff --git a/tortoise/__init__.py b/tortoise/__init__.py index 3b4e36986..e46917488 100644 --- a/tortoise/__init__.py +++ b/tortoise/__init__.py @@ -3,7 +3,6 @@ import json import os import warnings -from contextvars import ContextVar from copy import deepcopy from inspect import isclass from types import ModuleType @@ -13,6 +12,7 @@ from tortoise.backends.base.client import BaseDBAsyncClient from tortoise.backends.base.config_generator import expand_db_url, generate_config +from 
tortoise.connection import connections from tortoise.exceptions import ConfigurationError from tortoise.fields.relational import ( BackwardFKRelation, @@ -24,13 +24,11 @@ from tortoise.filters import get_m2m_filters from tortoise.log import logger from tortoise.models import Model, ModelMeta -from tortoise.transactions import current_transaction_map from tortoise.utils import generate_schema_for_client class Tortoise: apps: Dict[str, Dict[str, Type["Model"]]] = {} - _connections: Dict[str, BaseDBAsyncClient] = {} _inited: bool = False @classmethod @@ -38,9 +36,13 @@ def get_connection(cls, connection_name: str) -> BaseDBAsyncClient: """ Returns the connection by name. - :raises KeyError: If connection name does not exist. + :raises ConfigurationError: If connection name does not exist. + + .. warning:: + This is deprecated and will be removed in a future release. Please use + :meth:`connections.get` instead. """ - return cls._connections[connection_name] + return connections.get(connection_name) @classmethod def describe_model( @@ -336,17 +338,6 @@ def split_reference(reference: str) -> Tuple[str, str]: model._meta.filters.update(get_m2m_filters(field, m2m_object)) related_model._meta.add_field(backward_relation_name, m2m_relation) - @classmethod - def _discover_client_class(cls, engine: str) -> Type[BaseDBAsyncClient]: - # Let exception bubble up for transparency - engine_module = importlib.import_module(engine) - - try: - client_class = engine_module.client_class - except AttributeError: - raise ConfigurationError(f'Backend for engine "{engine}" does not implement db client') - return client_class - @classmethod def _discover_models( cls, models_path: Union[ModuleType, str], app_label: str @@ -376,21 +367,6 @@ def _discover_models( warnings.warn(f'Module "{models_path}" has no models', RuntimeWarning, stacklevel=4) return discovered_models - @classmethod - async def _init_connections(cls, connections_config: dict, create_db: bool) -> None: - for name, info in connections_config.items(): - if isinstance(info, str): - info = expand_db_url(info) - client_class = cls._discover_client_class(info.get("engine")) - db_params = info["credentials"].copy() - db_params.update({"connection_name": name}) - connection = client_class(**db_params) - if create_db: - await connection.db_create() - await connection.create_connection(with_db=True) - cls._connections[name] = connection - current_transaction_map[name] = ContextVar(name, default=connection) - @classmethod def init_models( cls, @@ -423,7 +399,7 @@ def init_models( def _init_apps(cls, apps_config: dict) -> None: for name, info in apps_config.items(): try: - cls.get_connection(info.get("default_connection", "default")) + connections.get(info.get("default_connection", "default")) except KeyError: raise ConfigurationError( 'Unknown connection "{}" for app "{}"'.format( @@ -542,7 +518,7 @@ async def init( :raises ConfigurationError: For any configuration error """ if cls._inited: - await cls.close_connections() + await connections.close_all(discard=True) await cls._reset_apps() if int(bool(config) + bool(config_file) + bool(db_url)) != 1: raise ConfigurationError( @@ -595,7 +571,7 @@ async def init( ) cls._init_timezone(use_tz, timezone) - await cls._init_connections(connections_config, _create_db) + await connections._init(connections_config, _create_db) cls._init_apps(apps_config) cls._init_routers(routers) @@ -628,12 +604,12 @@ async def close_connections(cls) -> None: It is required for this to be called on exit, else your event loop may never 
complete as it is waiting for the connections to die. + + .. warning:: + This is deprecated and will be removed in a future release. Please use + :meth:`connections.close_all` instead. """ - tasks = [] - for connection in cls._connections.values(): - tasks.append(connection.close()) - await asyncio.gather(*tasks) - cls._connections = {} + await connections.close_all() logger.info("Tortoise-ORM shutdown") @classmethod @@ -643,7 +619,6 @@ async def _reset_apps(cls) -> None: if isinstance(model, ModelMeta): model._meta.default_connection = None cls.apps.clear() - current_transaction_map.clear() @classmethod async def generate_schemas(cls, safe: bool = True) -> None: @@ -658,7 +633,7 @@ async def generate_schemas(cls, safe: bool = True) -> None: """ if not cls._inited: raise ConfigurationError("You have to call .init() first before generating schemas") - for connection in cls._connections.values(): + for connection in connections.all(): await generate_schema_for_client(connection, safe) @classmethod @@ -671,10 +646,13 @@ async def _drop_databases(cls) -> None: """ if not cls._inited: raise ConfigurationError("You have to call .init() first before deleting schemas") - for connection in cls._connections.values(): - await connection.close() - await connection.db_delete() - cls._connections = {} + # this closes any existing connections/pool if any and clears + # the storage + await connections.close_all(discard=False) + for conn in connections.all(): + await conn.db_delete() + connections.discard(conn.connection_name) + await cls._reset_apps() @classmethod @@ -706,7 +684,7 @@ async def do_stuff(): try: loop.run_until_complete(coro) finally: - loop.run_until_complete(Tortoise.close_connections()) + loop.run_until_complete(connections.close_all(discard=True)) __version__ = "0.18.2" diff --git a/tortoise/backends/asyncpg/client.py b/tortoise/backends/asyncpg/client.py index 15412b16d..69412920b 100644 --- a/tortoise/backends/asyncpg/client.py +++ b/tortoise/backends/asyncpg/client.py @@ -1,5 +1,5 @@ import asyncio -from typing import Any, Callable, List, Optional, Tuple, TypeVar +from typing import Any, Callable, List, Optional, Tuple, TypeVar, Union import asyncpg from asyncpg.transaction import Transaction @@ -10,6 +10,7 @@ BaseTransactionWrapper, ConnectionWrapper, NestedTransactionPooledContext, + PoolConnectionWrapper, TransactionContext, TransactionContextPooled, ) @@ -88,6 +89,13 @@ async def db_delete(self) -> None: return await super().db_delete() except asyncpg.InvalidCatalogNameError: # pragma: nocoverage pass + await self.close() + + def acquire_connection(self) -> Union["PoolConnectionWrapper", "ConnectionWrapper"]: + return PoolConnectionWrapper(self) + + def _in_transaction(self) -> "TransactionContext": + return TransactionContextPooled(TransactionWrapper(self)) @translate_exceptions async def execute_insert(self, query: str, values: list) -> Optional[asyncpg.Record]: @@ -140,9 +148,6 @@ async def execute_query_dict(self, query: str, values: Optional[list] = None) -> return list(map(dict, await connection.fetch(query, *values))) return list(map(dict, await connection.fetch(query))) - def _in_transaction(self) -> "TransactionContext": - return TransactionContextPooled(TransactionWrapper(self)) - class TransactionWrapper(AsyncpgDBClient, BaseTransactionWrapper): def __init__(self, connection: AsyncpgDBClient) -> None: @@ -153,13 +158,13 @@ def __init__(self, connection: AsyncpgDBClient) -> None: self.connection_name = connection.connection_name self.transaction: Transaction = None 
self._finalized = False - self._parent = connection + self._parent: AsyncpgDBClient = connection def _in_transaction(self) -> "TransactionContext": return NestedTransactionPooledContext(self) def acquire_connection(self) -> "ConnectionWrapper": - return ConnectionWrapper(self._connection, self._lock) + return ConnectionWrapper(self._lock, self) @translate_exceptions async def execute_many(self, query: str, values: list) -> None: diff --git a/tortoise/backends/base/client.py b/tortoise/backends/base/client.py index 3070f252a..5972272e6 100644 --- a/tortoise/backends/base/client.py +++ b/tortoise/backends/base/client.py @@ -5,9 +5,9 @@ from tortoise.backends.base.executor import BaseExecutor from tortoise.backends.base.schema_generator import BaseSchemaGenerator +from tortoise.connection import connections from tortoise.exceptions import TransactionManagementError from tortoise.log import db_client_logger -from tortoise.transactions import current_transaction_map class Capabilities: @@ -203,13 +203,21 @@ async def execute_query_dict(self, query: str, values: Optional[list] = None) -> class ConnectionWrapper: - __slots__ = ("connection", "lock") + __slots__ = ("connection", "lock", "client") - def __init__(self, connection: Any, lock: asyncio.Lock) -> None: - self.connection = connection + def __init__(self, lock: asyncio.Lock, client: Any) -> None: + """Wraps the connections with a lock to facilitate safe concurrent access.""" self.lock: asyncio.Lock = lock + self.client = client + self.connection: Any = client._connection + + async def ensure_connection(self) -> None: + if not self.connection: + await self.client.create_connection(with_db=True) + self.connection = self.client._connection async def __aenter__(self): + await self.ensure_connection() await self.lock.acquire() return self.connection @@ -225,10 +233,15 @@ def __init__(self, connection: Any) -> None: self.connection_name = connection.connection_name self.lock = getattr(connection, "_trxlock", None) + async def ensure_connection(self) -> None: + if not self.connection._connection: + await self.connection._parent.create_connection(with_db=True) + self.connection._connection = self.connection._parent._connection + async def __aenter__(self): + await self.ensure_connection() await self.lock.acquire() # type:ignore - current_transaction = current_transaction_map[self.connection_name] - self.token = current_transaction.set(self.connection) + self.token = connections.set(self.connection_name, self.connection) await self.connection.start() return self.connection @@ -240,16 +253,20 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: await self.connection.rollback() else: await self.connection.commit() - current_transaction_map[self.connection_name].reset(self.token) + connections.reset(self.token) self.lock.release() # type:ignore class TransactionContextPooled(TransactionContext): - __slots__ = ("connection", "connection_name", "token") + __slots__ = ("conn_wrapper", "connection", "connection_name", "token") + + async def ensure_connection(self) -> None: + if not self.connection._parent._pool: + await self.connection._parent.create_connection(with_db=True) async def __aenter__(self): - current_transaction = current_transaction_map[self.connection_name] - self.token = current_transaction.set(self.connection) + await self.ensure_connection() + self.token = connections.set(self.connection_name, self.connection) self.connection._connection = await self.connection._parent._pool.acquire() await 
self.connection.start() return self.connection @@ -262,9 +279,9 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: await self.connection.rollback() else: await self.connection.commit() - current_transaction_map[self.connection_name].reset(self.token) if self.connection._parent._pool: await self.connection._parent._pool.release(self.connection._connection) + connections.reset(self.token) class NestedTransactionContext(TransactionContext): @@ -294,11 +311,19 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: class PoolConnectionWrapper: - def __init__(self, pool: Any) -> None: - self.pool = pool + def __init__(self, client: Any) -> None: + """Class to manage acquiring from and releasing connections to a pool.""" + self.pool = client._pool + self.client = client self.connection = None + async def ensure_connection(self) -> None: + if not self.pool: + await self.client.create_connection(with_db=True) + self.pool = self.client._pool + async def __aenter__(self): + await self.ensure_connection() # get first available connection self.connection = await self.pool.acquire() return self.connection diff --git a/tortoise/backends/mysql/client.py b/tortoise/backends/mysql/client.py index 675f27dc5..cdf54edc3 100644 --- a/tortoise/backends/mysql/client.py +++ b/tortoise/backends/mysql/client.py @@ -159,7 +159,7 @@ async def db_delete(self) -> None: await self.close() def acquire_connection(self) -> Union["ConnectionWrapper", "PoolConnectionWrapper"]: - return PoolConnectionWrapper(self._pool) + return PoolConnectionWrapper(self) def _in_transaction(self) -> "TransactionContext": return TransactionContextPooled(TransactionWrapper(self)) @@ -229,7 +229,7 @@ def _in_transaction(self) -> "TransactionContext": return NestedTransactionPooledContext(self) def acquire_connection(self) -> ConnectionWrapper: - return ConnectionWrapper(self._connection, self._lock) + return ConnectionWrapper(self._lock, self) @translate_exceptions async def execute_many(self, query: str, values: list) -> None: diff --git a/tortoise/backends/psycopg/client.py b/tortoise/backends/psycopg/client.py index ff358d477..97a53a047 100644 --- a/tortoise/backends/psycopg/client.py +++ b/tortoise/backends/psycopg/client.py @@ -12,6 +12,7 @@ import tortoise.backends.base_postgres.client as postgres_client import tortoise.backends.psycopg.executor as executor import tortoise.exceptions as exceptions +from tortoise.backends.base.client import PoolConnectionWrapper from tortoise.backends.psycopg.schema_generator import PsycopgSchemaGenerator FuncType = typing.Callable[..., typing.Any] @@ -27,61 +28,9 @@ async def release(self, connection: psycopg.AsyncConnection): await self.putconn(connection) -class PoolConnectionWrapper(base_client.PoolConnectionWrapper): - pool: AsyncConnectionPool - # Using _connection instead of connection because of mypy - _connection: typing.Optional[psycopg.AsyncConnection] = None - - def __init__(self, pool: AsyncConnectionPool) -> None: - self.pool = pool - - async def open(self, timeout: float = 30.0): - try: - await self.pool.open(wait=True, timeout=timeout) - except psycopg_pool.PoolTimeout as exception: - raise exceptions.DBConnectionError from exception - - async def close(self, timeout: float = 0.0): - await self.pool.close(timeout=timeout) - - async def __aenter__(self) -> psycopg.AsyncConnection: - if self.pool is None: - raise RuntimeError("Connection pool is not initialized") - - if self._connection: - raise RuntimeError("Connection is already 
acquired") - else: - self._connection = await self.pool.getconn() - - return await self._connection.__aenter__() - - async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: - if self._connection: - await self._connection.__aexit__(exc_type, exc_val, exc_tb) - await self.pool.putconn(self._connection) - - self._connection = None - - # TortoiseORM has this interface hardcoded in the tests, so we need to support it - async def acquire(self) -> psycopg.AsyncConnection: - if not self._connection: - await self.__aenter__() - - if not self._connection: - raise RuntimeError("Connection is not acquired") - return self._connection - - # TortoiseORM has this interface hardcoded in the tests, so we need to support it - async def release(self, connection: psycopg.AsyncConnection) -> None: - if self._connection is not connection: - raise RuntimeError("Wrong connection is being released") - await self.__aexit__(None, None, None) - - class PsycopgClient(postgres_client.BasePostgresClient): executor_class: typing.Type[executor.PsycopgExecutor] = executor.PsycopgExecutor schema_generator: typing.Type[PsycopgSchemaGenerator] = PsycopgSchemaGenerator - # _pool: typing.Optional[PoolConnectionWrapper] = None _pool: typing.Optional[AsyncConnectionPool] = None _connection: psycopg.AsyncConnection default_timeout: float = 30 @@ -230,11 +179,7 @@ async def _translate_exceptions(self, func, *args, **kwargs) -> Exception: def acquire_connection( self, ) -> typing.Union[base_client.ConnectionWrapper, PoolConnectionWrapper]: - if not self._pool: - raise exceptions.OperationalError("Connection pool not initialized") - - pool_wrapper: PoolConnectionWrapper = PoolConnectionWrapper(self._pool) - return pool_wrapper + return PoolConnectionWrapper(self) def _in_transaction(self) -> base_client.TransactionContext: return base_client.TransactionContextPooled(TransactionWrapper(self)) @@ -256,7 +201,7 @@ def _in_transaction(self) -> base_client.TransactionContext: return base_client.NestedTransactionPooledContext(self) def acquire_connection(self) -> base_client.ConnectionWrapper: - return base_client.ConnectionWrapper(self._connection, self._lock) + return base_client.ConnectionWrapper(self._lock, self) @postgres_client.translate_exceptions async def start(self) -> None: diff --git a/tortoise/backends/sqlite/client.py b/tortoise/backends/sqlite/client.py index 5f631e029..fe2071d25 100644 --- a/tortoise/backends/sqlite/client.py +++ b/tortoise/backends/sqlite/client.py @@ -101,7 +101,7 @@ async def db_delete(self) -> None: raise e def acquire_connection(self) -> ConnectionWrapper: - return ConnectionWrapper(self._connection, self._lock) + return ConnectionWrapper(self._lock, self) def _in_transaction(self) -> "TransactionContext": return TransactionContext(TransactionWrapper(self)) @@ -160,6 +160,7 @@ def __init__(self, connection: SqliteClient) -> None: self.log = connection.log self._finalized = False self.fetch_inserted = connection.fetch_inserted + self._parent = connection def _in_transaction(self) -> "TransactionContext": return NestedTransactionContext(self) diff --git a/tortoise/connection.py b/tortoise/connection.py new file mode 100644 index 000000000..ccc56c0ff --- /dev/null +++ b/tortoise/connection.py @@ -0,0 +1,192 @@ +import asyncio +import contextvars +import importlib +from contextvars import ContextVar +from copy import copy +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union + +from tortoise.backends.base.config_generator import expand_db_url +from tortoise.exceptions import 
ConfigurationError + +if TYPE_CHECKING: + from tortoise.backends.base.client import BaseDBAsyncClient + + DBConfigType = Dict[str, Any] + + +class ConnectionHandler: + _conn_storage: ContextVar[Dict[str, "BaseDBAsyncClient"]] = contextvars.ContextVar( + "_conn_storage", default={} + ) + + def __init__(self) -> None: + """Unified connection management interface.""" + self._db_config: Optional["DBConfigType"] = None + self._create_db: bool = False + + async def _init(self, db_config: "DBConfigType", create_db: bool): + self._db_config = db_config + self._create_db = create_db + await self._init_connections() + + @property + def db_config(self) -> "DBConfigType": + """ + Return the DB config. + + This is the same config that was passed to the + :meth:`Tortoise.init` call during initialization. + + :raises ConfigurationError: + If this property is accessed before calling the + :meth:`Tortoise.init` method. + """ + if self._db_config is None: + raise ConfigurationError( + "DB configuration not initialised. Make sure to call " + "Tortoise.init with a valid configuration before attempting " + "to create connections." + ) + return self._db_config + + def _get_storage(self) -> Dict[str, "BaseDBAsyncClient"]: + return self._conn_storage.get() + + def _set_storage(self, new_storage: Dict[str, "BaseDBAsyncClient"]) -> contextvars.Token: + # Should be used only for testing purposes. + return self._conn_storage.set(new_storage) + + def _copy_storage(self) -> Dict[str, "BaseDBAsyncClient"]: + return copy(self._get_storage()) + + def _clear_storage(self) -> None: + self._get_storage().clear() + + def _discover_client_class(self, engine: str) -> Type["BaseDBAsyncClient"]: + # Let exception bubble up for transparency + engine_module = importlib.import_module(engine) + try: + client_class = engine_module.client_class + except AttributeError: + raise ConfigurationError(f'Backend for engine "{engine}" does not implement db client') + return client_class + + def _get_db_info(self, conn_alias: str) -> Union[str, Dict]: + try: + return self.db_config[conn_alias] + except KeyError: + raise ConfigurationError( + f"Unable to get db settings for alias '{conn_alias}'. Please " + f"check if the config dict contains this alias and try again" + ) + + async def _init_connections(self) -> None: + for alias in self.db_config: + connection: "BaseDBAsyncClient" = self.get(alias) + if self._create_db: + await connection.db_create() + + def _create_connection(self, conn_alias: str) -> "BaseDBAsyncClient": + db_info = self._get_db_info(conn_alias) + if isinstance(db_info, str): + db_info = expand_db_url(db_info) + client_class = self._discover_client_class(db_info.get("engine", "")) + db_params = db_info["credentials"].copy() + db_params.update({"connection_name": conn_alias}) + connection: "BaseDBAsyncClient" = client_class(**db_params) + return connection + + def get(self, conn_alias: str) -> "BaseDBAsyncClient": + """ + Return the connection object for the given alias, creating it if needed. + + Use this to access the low-level connection object + (:class:`BaseDBAsyncClient`) for the given alias. + + :param conn_alias: The alias for which to fetch the connection + + :raises ConfigurationError: If the connection alias does not exist. 
+ """ + storage: Dict[str, "BaseDBAsyncClient"] = self._get_storage() + try: + return storage[conn_alias] + except KeyError: + connection: BaseDBAsyncClient = self._create_connection(conn_alias) + storage[conn_alias] = connection + return connection + + def set(self, conn_alias: str, conn_obj: "BaseDBAsyncClient") -> contextvars.Token: + """ + Sets the given alias to the provided connection object. + + :param conn_alias: The alias to set the connection for. + :param conn_obj: The connection object that needs to be set for this alias. + + .. note:: + This method copies the storage from the `current context`, updates the + ``conn_alias`` with the provided ``conn_obj`` and sets the updated storage + in a `new context` and therefore returns a ``contextvars.Token`` in order to restore + the original context storage. + """ + storage_copy = self._copy_storage() + storage_copy[conn_alias] = conn_obj + return self._conn_storage.set(storage_copy) + + def discard(self, conn_alias: str) -> Optional["BaseDBAsyncClient"]: + """ + Discards the given alias from the storage in the `current context`. + + :param conn_alias: The alias for which the connection object should be discarded. + + .. important:: + Make sure to have called ``conn.close()`` for the provided alias before calling + this method else there would be a connection leak (dangling connection). + """ + return self._get_storage().pop(conn_alias, None) + + def reset(self, token: contextvars.Token) -> None: + """ + Reset the underlying storage to the previous context state. + + Resets the storage state to the `context` associated with the provided token. After + resetting storage state, any additional `connections` created in the `old context` are + copied into the `current context`. + + :param token: + The token corresponding to the `context` to which the storage state has to + be reset. Typically, this token is obtained by calling the + :meth:`set` method of this class. + """ + current_storage = self._get_storage() + self._conn_storage.reset(token) + prev_storage = self._get_storage() + for alias, conn in current_storage.items(): + if alias not in prev_storage: + prev_storage[alias] = conn + + def all(self) -> List["BaseDBAsyncClient"]: + """Returns a list of connection objects from the storage in the `current context`.""" + # The reason this method iterates over db_config and not over `storage` directly is + # because: assume that someone calls `discard` with a certain alias, and calls this + # method subsequently. The alias which just got discarded from the storage would not + # appear in the returned list though it exists as part of the `db_config`. + return [self.get(alias) for alias in self.db_config] + + async def close_all(self, discard: bool = True) -> None: + """ + Closes all connections in the storage in the `current context`. + + All closed connections will be removed from the storage by default. + + :param discard: + If ``False``, all connection objects are closed but `retained` in the storage. 
+ """ + tasks = [conn.close() for conn in self.all()] + await asyncio.gather(*tasks) + if discard: + for alias in self.db_config: + self.discard(alias) + + +connections = ConnectionHandler() diff --git a/tortoise/contrib/aiohttp/__init__.py b/tortoise/contrib/aiohttp/__init__.py index 085bcb966..c2a171221 100644 --- a/tortoise/contrib/aiohttp/__init__.py +++ b/tortoise/contrib/aiohttp/__init__.py @@ -3,7 +3,7 @@ from aiohttp import web # pylint: disable=E0401 -from tortoise import Tortoise +from tortoise import Tortoise, connections from tortoise.log import logger @@ -79,13 +79,13 @@ def register_tortoise( async def init_orm(app): # pylint: disable=W0612 await Tortoise.init(config=config, config_file=config_file, db_url=db_url, modules=modules) - logger.info(f"Tortoise-ORM started, {Tortoise._connections}, {Tortoise.apps}") + logger.info(f"Tortoise-ORM started, {connections._get_storage()}, {Tortoise.apps}") if generate_schemas: logger.info("Tortoise-ORM generating schema") await Tortoise.generate_schemas() async def close_orm(app): # pylint: disable=W0612 - await Tortoise.close_connections() + await connections.close_all() logger.info("Tortoise-ORM shutdown") app.on_startup.append(init_orm) diff --git a/tortoise/contrib/blacksheep/__init__.py b/tortoise/contrib/blacksheep/__init__.py index bcde53cb7..6b832ffbb 100644 --- a/tortoise/contrib/blacksheep/__init__.py +++ b/tortoise/contrib/blacksheep/__init__.py @@ -5,7 +5,7 @@ from blacksheep.server import Application from blacksheep.server.responses import json -from tortoise import Tortoise +from tortoise import Tortoise, connections from tortoise.exceptions import DoesNotExist, IntegrityError from tortoise.log import logger @@ -87,14 +87,14 @@ def register_tortoise( @app.on_start async def init_orm(context) -> None: # pylint: disable=W0612 await Tortoise.init(config=config, config_file=config_file, db_url=db_url, modules=modules) - logger.info("Tortoise-ORM started, %s, %s", Tortoise._connections, Tortoise.apps) + logger.info("Tortoise-ORM started, %s, %s", connections._get_storage(), Tortoise.apps) if generate_schemas: logger.info("Tortoise-ORM generating schema") await Tortoise.generate_schemas() @app.on_stop async def close_orm(context) -> None: # pylint: disable=W0612 - await Tortoise.close_connections() + await connections.close_all() logger.info("Tortoise-ORM shutdown") if add_exception_handlers: diff --git a/tortoise/contrib/fastapi/__init__.py b/tortoise/contrib/fastapi/__init__.py index 45a5f3ea9..8912cac1f 100644 --- a/tortoise/contrib/fastapi/__init__.py +++ b/tortoise/contrib/fastapi/__init__.py @@ -5,7 +5,7 @@ from fastapi.responses import JSONResponse from pydantic import BaseModel # pylint: disable=E0611 -from tortoise import Tortoise +from tortoise import Tortoise, connections from tortoise.exceptions import DoesNotExist, IntegrityError from tortoise.log import logger @@ -91,14 +91,14 @@ def register_tortoise( @app.on_event("startup") async def init_orm() -> None: # pylint: disable=W0612 await Tortoise.init(config=config, config_file=config_file, db_url=db_url, modules=modules) - logger.info("Tortoise-ORM started, %s, %s", Tortoise._connections, Tortoise.apps) + logger.info("Tortoise-ORM started, %s, %s", connections._get_storage(), Tortoise.apps) if generate_schemas: logger.info("Tortoise-ORM generating schema") await Tortoise.generate_schemas() @app.on_event("shutdown") async def close_orm() -> None: # pylint: disable=W0612 - await Tortoise.close_connections() + await connections.close_all() logger.info("Tortoise-ORM 
shutdown") if add_exception_handlers: diff --git a/tortoise/contrib/quart/__init__.py b/tortoise/contrib/quart/__init__.py index e13c25053..184d3b0e2 100644 --- a/tortoise/contrib/quart/__init__.py +++ b/tortoise/contrib/quart/__init__.py @@ -5,7 +5,7 @@ from quart import Quart # pylint: disable=E0401 -from tortoise import Tortoise +from tortoise import Tortoise, connections from tortoise.log import logger @@ -84,14 +84,14 @@ def register_tortoise( @app.before_serving async def init_orm() -> None: # pylint: disable=W0612 await Tortoise.init(config=config, config_file=config_file, db_url=db_url, modules=modules) - logger.info("Tortoise-ORM started, %s, %s", Tortoise._connections, Tortoise.apps) + logger.info("Tortoise-ORM started, %s, %s", connections._get_storage(), Tortoise.apps) if _generate_schemas: logger.info("Tortoise-ORM generating schema") await Tortoise.generate_schemas() @app.after_serving async def close_orm() -> None: # pylint: disable=W0612 - await Tortoise.close_connections() + await connections.close_all() logger.info("Tortoise-ORM shutdown") @app.cli.command() # type: ignore @@ -103,7 +103,7 @@ async def inner() -> None: config=config, config_file=config_file, db_url=db_url, modules=modules ) await Tortoise.generate_schemas() - await Tortoise.close_connections() + await connections.close_all() logger.setLevel(logging.DEBUG) loop = asyncio.get_event_loop() diff --git a/tortoise/contrib/sanic/__init__.py b/tortoise/contrib/sanic/__init__.py index c1fd8e526..7430dfd52 100644 --- a/tortoise/contrib/sanic/__init__.py +++ b/tortoise/contrib/sanic/__init__.py @@ -3,7 +3,7 @@ from sanic import Sanic # pylint: disable=E0401 -from tortoise import Tortoise +from tortoise import Tortoise, connections from tortoise.log import logger @@ -80,12 +80,12 @@ def register_tortoise( @app.listener("before_server_start") # type:ignore async def init_orm(app, loop): # pylint: disable=W0612 await Tortoise.init(config=config, config_file=config_file, db_url=db_url, modules=modules) - logger.info("Tortoise-ORM started, %s, %s", Tortoise._connections, Tortoise.apps) + logger.info("Tortoise-ORM started, %s, %s", connections._get_storage(), Tortoise.apps) if generate_schemas: logger.info("Tortoise-ORM generating schema") await Tortoise.generate_schemas() @app.listener("after_server_stop") # type:ignore async def close_orm(app, loop): # pylint: disable=W0612 - await Tortoise.close_connections() + await connections.close_all() logger.info("Tortoise-ORM shutdown") diff --git a/tortoise/contrib/starlette/__init__.py b/tortoise/contrib/starlette/__init__.py index d5952155e..a42c63f3f 100644 --- a/tortoise/contrib/starlette/__init__.py +++ b/tortoise/contrib/starlette/__init__.py @@ -3,7 +3,7 @@ from starlette.applications import Starlette # pylint: disable=E0401 -from tortoise import Tortoise +from tortoise import Tortoise, connections from tortoise.log import logger @@ -80,12 +80,12 @@ def register_tortoise( @app.on_event("startup") async def init_orm() -> None: # pylint: disable=W0612 await Tortoise.init(config=config, config_file=config_file, db_url=db_url, modules=modules) - logger.info("Tortoise-ORM started, %s, %s", Tortoise._connections, Tortoise.apps) + logger.info("Tortoise-ORM started, %s, %s", connections._get_storage(), Tortoise.apps) if generate_schemas: logger.info("Tortoise-ORM generating schema") await Tortoise.generate_schemas() @app.on_event("shutdown") async def close_orm() -> None: # pylint: disable=W0612 - await Tortoise.close_connections() + await connections.close_all() 
logger.info("Tortoise-ORM shutdown") diff --git a/tortoise/contrib/test/__init__.py b/tortoise/contrib/test/__init__.py index 4554941f8..2951114a5 100644 --- a/tortoise/contrib/test/__init__.py +++ b/tortoise/contrib/test/__init__.py @@ -1,5 +1,4 @@ import asyncio -import os import os as _os import unittest from asyncio.events import AbstractEventLoop @@ -8,10 +7,9 @@ from typing import Any, Iterable, List, Optional, Union from unittest import SkipTest, expectedFailure, skip, skipIf, skipUnless -from tortoise import Model, Tortoise +from tortoise import Model, Tortoise, connections from tortoise.backends.base.config_generator import generate_config as _generate_config -from tortoise.exceptions import DBConnectionError -from tortoise.transactions import current_transaction_map +from tortoise.exceptions import DBConnectionError, OperationalError __all__ = ( "SimpleTestCase", @@ -43,7 +41,7 @@ _SELECTOR = None _LOOP: AbstractEventLoop = None # type: ignore _MODULES: Iterable[Union[str, ModuleType]] = [] -_CONN_MAP: dict = {} +_CONN_CONFIG: dict = {} def getDBConfig(app_label: str, modules: Iterable[Union[str, ModuleType]]) -> dict: @@ -62,10 +60,12 @@ def getDBConfig(app_label: str, modules: Iterable[Union[str, ModuleType]]) -> di async def _init_db(config: dict) -> None: + # Placing init outside the try block since it doesn't + # establish connections to the DB eagerly. + await Tortoise.init(config) try: - await Tortoise.init(config) await Tortoise._drop_databases() - except DBConnectionError: # pragma: nocoverage + except (DBConnectionError, OperationalError): # pragma: nocoverage pass await Tortoise.init(config, _create_db=True) @@ -74,8 +74,8 @@ async def _init_db(config: dict) -> None: def _restore_default() -> None: Tortoise.apps = {} - Tortoise._connections = _CONNECTIONS.copy() - current_transaction_map.update(_CONN_MAP) + connections._get_storage().update(_CONNECTIONS.copy()) + connections._db_config = _CONN_CONFIG.copy() Tortoise._init_apps(_CONFIG["apps"]) Tortoise._inited = True @@ -101,20 +101,20 @@ def initializer( global _LOOP global _TORTOISE_TEST_DB global _MODULES - global _CONN_MAP + global _CONN_CONFIG _MODULES = modules if db_url is not None: # pragma: nobranch _TORTOISE_TEST_DB = db_url _CONFIG = getDBConfig(app_label=app_label, modules=_MODULES) - loop = loop or asyncio.get_event_loop() _LOOP = loop _SELECTOR = loop._selector # type: ignore loop.run_until_complete(_init_db(_CONFIG)) - _CONNECTIONS = Tortoise._connections.copy() - _CONN_MAP = current_transaction_map.copy() + _CONNECTIONS = connections._copy_storage() + _CONN_CONFIG = connections.db_config.copy() + connections._clear_storage() + connections.db_config.clear() Tortoise.apps = {} - Tortoise._connections = {} Tortoise._inited = False @@ -189,10 +189,15 @@ def _tearDownAsyncioLoop(self): async def asyncSetUp(self) -> None: await self._setUpDB() + def _reset_conn_state(self) -> None: + # clearing the storage and db config + connections._clear_storage() + connections.db_config.clear() + async def asyncTearDown(self) -> None: await self._tearDownDB() + self._reset_conn_state() Tortoise.apps = {} - Tortoise._connections = {} Tortoise._inited = False def assertListSortEqual( @@ -231,13 +236,12 @@ class IsolatedTestCase(SimpleTestCase): tortoise_test_modules: Iterable[Union[str, ModuleType]] = [] async def _setUpDB(self) -> None: + await super()._setUpDB() config = getDBConfig(app_label="models", modules=self.tortoise_test_modules or _MODULES) await Tortoise.init(config, _create_db=True) await 
Tortoise.generate_schemas(safe=False) - self._connections = Tortoise._connections.copy() async def _tearDownDB(self) -> None: - Tortoise._connections = self._connections.copy() await Tortoise._drop_databases() @@ -268,25 +272,39 @@ async def _tearDownDB(self) -> None: class TransactionTestContext: - __slots__ = ("connection", "connection_name", "token") + __slots__ = ("connection", "connection_name", "token", "uses_pool") def __init__(self, connection) -> None: self.connection = connection self.connection_name = connection.connection_name + self.uses_pool = hasattr(self.connection._parent, "_pool") - async def __aenter__(self): - current_transaction = current_transaction_map[self.connection_name] - self.token = current_transaction.set(self.connection) - if hasattr(self.connection, "_parent"): + async def ensure_connection(self) -> None: + is_conn_established = self.connection._connection is not None + if self.uses_pool: + is_conn_established = self.connection._parent._pool is not None + + # If the underlying pool/connection hasn't been established then + # first create the pool/connection + if not is_conn_established: + await self.connection._parent.create_connection(with_db=True) + + if self.uses_pool: self.connection._connection = await self.connection._parent._pool.acquire() + else: + self.connection._connection = self.connection._parent._connection + + async def __aenter__(self): + await self.ensure_connection() + self.token = connections.set(self.connection_name, self.connection) await self.connection.start() return self.connection async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: await self.connection.rollback() - if hasattr(self.connection, "_parent"): + if self.uses_pool: await self.connection._parent._pool.release(self.connection._connection) - current_transaction_map[self.connection_name].reset(self.token) + connections.reset(self.token) class TestCase(TruncationTestCase): @@ -299,8 +317,7 @@ class TestCase(TruncationTestCase): async def asyncSetUp(self) -> None: await super(TestCase, self).asyncSetUp() - _restore_default() - self.__db__ = Tortoise.get_connection("models") + self.__db__ = connections.get("models") self.__transaction__ = TransactionTestContext(self.__db__._in_transaction().connection) await self.__transaction__.__aenter__() # type: ignore @@ -350,7 +367,7 @@ def decorator(test_item): @wraps(test_item) def skip_wrapper(*args, **kwargs): - db = Tortoise.get_connection(connection_name) + db = connections.get(connection_name) for key, val in conditions.items(): if getattr(db.capabilities, key) != val: raise SkipTest(f"Capability {key} != {val}") diff --git a/tortoise/models.py b/tortoise/models.py index af7cb9487..0fd730d84 100644 --- a/tortoise/models.py +++ b/tortoise/models.py @@ -22,6 +22,7 @@ from pypika import Order, Query, Table from pypika.terms import Term +from tortoise import connections from tortoise.backends.base.client import BaseDBAsyncClient from tortoise.exceptions import ( ConfigurationError, @@ -59,7 +60,7 @@ ) from tortoise.router import router from tortoise.signals import Signals -from tortoise.transactions import current_transaction_map, in_transaction +from tortoise.transactions import in_transaction MODEL = TypeVar("MODEL", bound="Model") EMPTY = object() @@ -278,10 +279,11 @@ def add_field(self, name: str, value: Field) -> None: @property def db(self) -> BaseDBAsyncClient: - try: - return current_transaction_map[self.default_connection].get() - except KeyError: - raise ConfigurationError("No DB associated to model") + if 
self.default_connection is None: + raise ConfigurationError( + f"default_connection for the model {self._model} cannot be None" + ) + return connections.get(self.default_connection) @property def ordering(self) -> Tuple[Tuple[str, Order], ...]: diff --git a/tortoise/router.py b/tortoise/router.py index 885398d70..7ce5f8503 100644 --- a/tortoise/router.py +++ b/tortoise/router.py @@ -1,7 +1,7 @@ from typing import TYPE_CHECKING, List, Optional, Type -from tortoise.exceptions import ParamsError -from tortoise.transactions import get_connection +from tortoise.connection import connections +from tortoise.exceptions import ConfigurationError if TYPE_CHECKING: from tortoise import BaseDBAsyncClient, Model @@ -28,8 +28,8 @@ def _router_func(self, model: Type["Model"], action: str): def _db_route(self, model: Type["Model"], action: str): try: - return get_connection(self._router_func(model, action)) - except ParamsError: + return connections.get(self._router_func(model, action)) + except ConfigurationError: return None def db_for_read(self, model: Type["Model"]) -> Optional["BaseDBAsyncClient"]: diff --git a/tortoise/transactions.py b/tortoise/transactions.py index 02c41f122..8a1d92b88 100644 --- a/tortoise/transactions.py +++ b/tortoise/transactions.py @@ -1,10 +1,9 @@ from functools import wraps from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, cast +from tortoise import connections from tortoise.exceptions import ParamsError -current_transaction_map: dict = {} - if TYPE_CHECKING: # pragma: nocoverage from tortoise.backends.base.client import BaseDBAsyncClient, TransactionContext @@ -12,18 +11,16 @@ F = TypeVar("F", bound=FuncType) -def get_connection(connection_name: Optional[str]) -> "BaseDBAsyncClient": - from tortoise import Tortoise - +def _get_connection(connection_name: Optional[str]) -> "BaseDBAsyncClient": if connection_name: - connection = current_transaction_map[connection_name].get() - elif len(Tortoise._connections) == 1: - connection_name = list(Tortoise._connections.keys())[0] - connection = current_transaction_map[connection_name].get() + connection = connections.get(connection_name) + elif len(connections.db_config) == 1: + connection_name = next(iter(connections.db_config.keys())) + connection = connections.get(connection_name) else: raise ParamsError( "You are running with multiple databases, so you should specify" - f" connection_name: {list(Tortoise._connections.keys())}" + f" connection_name: {list(connections.db_config)}" ) return connection @@ -38,7 +35,7 @@ def in_transaction(connection_name: Optional[str] = None) -> "TransactionContext :param connection_name: name of connection to run with, optional if you have only one db connection """ - connection = get_connection(connection_name) + connection = _get_connection(connection_name) return connection._in_transaction()
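For reviewers, here is a minimal usage sketch of the new ``connections`` handler introduced by this patch. It is not part of the patch itself: the ``Thing`` model, the in-memory SQLite URL and the deliberately redundant ``set``/``reset`` round-trip are made up purely to illustrate the lazy ``get``, the context-local ``set``/``reset`` semantics and ``close_all``.

.. code-block:: python3

    import asyncio

    from tortoise import Tortoise, connections, fields
    from tortoise.models import Model


    class Thing(Model):
        # Hypothetical model, only here so the schema has something in it.
        id = fields.IntField(pk=True)
        name = fields.TextField()


    async def main() -> None:
        await Tortoise.init(db_url="sqlite://:memory:", modules={"models": ["__main__"]})
        await Tortoise.generate_schemas()

        # The client for an alias is created lazily on first access.
        default_client = connections.get("default")

        # `set` copies the context-local storage, binds the alias to the given
        # client in the copy, and returns a token for restoring the old storage.
        token = connections.set("default", default_client)
        try:
            await Thing.create(name="routed through the context-local client")
        finally:
            connections.reset(token)

        # Close every known connection; discard=True also drops them from storage.
        await connections.close_all(discard=True)


    if __name__ == "__main__":
        asyncio.run(main())

Binding aliases through a copied, context-local storage is what lets a transaction temporarily re-route an alias to its transactional client via ``connections.set`` and restore the previous routing with the returned token once the transaction exits, as the ``TransactionContext`` changes above do.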