From c2a398b28f4995f721dd46e3a5d56267b2e717c9 Mon Sep 17 00:00:00 2001 From: Aditya N Date: Sun, 5 Dec 2021 17:01:38 +0530 Subject: [PATCH 01/25] - Initial working commit for refactored connection management. - Tests are passing for SQLite DB. --- tests/model_setup/test__models__.py | 6 +- tests/model_setup/test_init.py | 722 ++++++++++----------- tests/schema/test_generate_schema.py | 12 +- tests/test_two_databases.py | 6 +- tests/utils/test_run_async.py | 14 +- tortoise/__init__.py | 27 +- tortoise/backends/asyncpg/client.py | 6 +- tortoise/backends/base/client.py | 48 +- tortoise/backends/base/schema_generator.py | 1 - tortoise/backends/mysql/client.py | 4 +- tortoise/backends/sqlite/client.py | 3 +- tortoise/connection.py | 123 ++++ tortoise/contrib/test/__init__.py | 77 ++- tortoise/models.py | 6 +- tortoise/router.py | 8 +- tortoise/transactions.py | 16 +- 16 files changed, 629 insertions(+), 450 deletions(-) create mode 100644 tortoise/connection.py diff --git a/tests/model_setup/test__models__.py b/tests/model_setup/test__models__.py index 90ce365bd..cbd81f136 100644 --- a/tests/model_setup/test__models__.py +++ b/tests/model_setup/test__models__.py @@ -5,7 +5,7 @@ import re from unittest.mock import AsyncMock, patch -from tortoise import Tortoise +from tortoise import Tortoise, connections from tortoise.contrib import test from tortoise.exceptions import ConfigurationError from tortoise.utils import get_schema_sql @@ -15,7 +15,6 @@ class TestGenerateSchema(test.SimpleTestCase): async def setUp(self): try: Tortoise.apps = {} - Tortoise._connections = {} Tortoise._inited = False except ConfigurationError: pass @@ -27,7 +26,6 @@ async def setUp(self): ] async def tearDown(self): - Tortoise._connections = {} await Tortoise._reset_apps() async def init_for(self, module: str, safe=False) -> None: @@ -47,7 +45,7 @@ async def init_for(self, module: str, safe=False) -> None: "apps": {"models": {"models": [module], "default_connection": "default"}}, } ) - self.sqls = get_schema_sql(Tortoise._connections["default"], safe).split(";\n") + self.sqls = get_schema_sql(connections.get("default"), safe).split(";\n") def get_sql(self, text: str) -> str: return str(re.sub(r"[ \t\n\r]+", " ", [sql for sql in self.sqls if text in sql][0])) diff --git a/tests/model_setup/test_init.py b/tests/model_setup/test_init.py index 3b70921b2..fb4393a1a 100644 --- a/tests/model_setup/test_init.py +++ b/tests/model_setup/test_init.py @@ -1,361 +1,361 @@ -import os - -from tortoise import Tortoise -from tortoise.contrib import test -from tortoise.exceptions import ConfigurationError - - -class TestInitErrors(test.SimpleTestCase): - async def setUp(self): - try: - Tortoise.apps = {} - Tortoise._connections = {} - Tortoise._inited = False - except ConfigurationError: - pass - Tortoise._inited = False - - async def tearDown(self): - await Tortoise.close_connections() - await Tortoise._reset_apps() - - async def test_basic_init(self): - await Tortoise.init( - { - "connections": { - "default": { - "engine": "tortoise.backends.sqlite", - "credentials": {"file_path": ":memory:"}, - } - }, - "apps": { - "models": {"models": ["tests.testmodels"], "default_connection": "default"} - }, - } - ) - self.assertIn("models", Tortoise.apps) - self.assertIsNotNone(Tortoise.get_connection("default")) - - async def test_empty_modules_init(self): - with self.assertWarnsRegex(RuntimeWarning, 'Module "tests.model_setup" has no models'): - await Tortoise.init( - { - "connections": { - "default": { - "engine": "tortoise.backends.sqlite", - "credentials": {"file_path": ":memory:"}, - } - }, - "apps": { - "models": {"models": ["tests.model_setup"], "default_connection": "default"} - }, - } - ) - - async def test_dup1_init(self): - with self.assertRaisesRegex( - ConfigurationError, 'backward relation "events" duplicates in model Tournament' - ): - await Tortoise.init( - { - "connections": { - "default": { - "engine": "tortoise.backends.sqlite", - "credentials": {"file_path": ":memory:"}, - } - }, - "apps": { - "models": { - "models": ["tests.model_setup.models_dup1"], - "default_connection": "default", - } - }, - } - ) - - async def test_dup2_init(self): - with self.assertRaisesRegex( - ConfigurationError, 'backward relation "events" duplicates in model Team' - ): - await Tortoise.init( - { - "connections": { - "default": { - "engine": "tortoise.backends.sqlite", - "credentials": {"file_path": ":memory:"}, - } - }, - "apps": { - "models": { - "models": ["tests.model_setup.models_dup2"], - "default_connection": "default", - } - }, - } - ) - - async def test_dup3_init(self): - with self.assertRaisesRegex( - ConfigurationError, 'backward relation "event" duplicates in model Tournament' - ): - await Tortoise.init( - { - "connections": { - "default": { - "engine": "tortoise.backends.sqlite", - "credentials": {"file_path": ":memory:"}, - } - }, - "apps": { - "models": { - "models": ["tests.model_setup.models_dup3"], - "default_connection": "default", - } - }, - } - ) - - async def test_generated_nonint(self): - with self.assertRaisesRegex( - ConfigurationError, "Field 'val' \\(CharField\\) can't be DB-generated" - ): - await Tortoise.init( - { - "connections": { - "default": { - "engine": "tortoise.backends.sqlite", - "credentials": {"file_path": ":memory:"}, - } - }, - "apps": { - "models": { - "models": ["tests.model_setup.model_generated_nonint"], - "default_connection": "default", - } - }, - } - ) - - async def test_multiple_pk(self): - with self.assertRaisesRegex( - ConfigurationError, - "Can't create model Tournament with two primary keys, only single primary key is supported", - ): - await Tortoise.init( - { - "connections": { - "default": { - "engine": "tortoise.backends.sqlite", - "credentials": {"file_path": ":memory:"}, - } - }, - "apps": { - "models": { - "models": ["tests.model_setup.model_multiple_pk"], - "default_connection": "default", - } - }, - } - ) - - async def test_nonpk_id(self): - with self.assertRaisesRegex( - ConfigurationError, - "Can't create model Tournament without explicit primary key if" - " field 'id' already present", - ): - await Tortoise.init( - { - "connections": { - "default": { - "engine": "tortoise.backends.sqlite", - "credentials": {"file_path": ":memory:"}, - } - }, - "apps": { - "models": { - "models": ["tests.model_setup.model_nonpk_id"], - "default_connection": "default", - } - }, - } - ) - - async def test_unknown_connection(self): - with self.assertRaisesRegex(ConfigurationError, 'Unknown connection "fioop"'): - await Tortoise.init( - { - "connections": { - "default": { - "engine": "tortoise.backends.sqlite", - "credentials": {"file_path": ":memory:"}, - } - }, - "apps": { - "models": {"models": ["tests.testmodels"], "default_connection": "fioop"} - }, - } - ) - - async def test_url_without_modules(self): - with self.assertRaisesRegex( - ConfigurationError, 'You must specify "db_url" and "modules" together' - ): - await Tortoise.init(db_url=f"sqlite://{':memory:'}") - - async def test_default_connection_init(self): - await Tortoise.init( - { - "connections": { - "default": { - "engine": "tortoise.backends.sqlite", - "credentials": {"file_path": ":memory:"}, - } - }, - "apps": {"models": {"models": ["tests.testmodels"]}}, - } - ) - self.assertIn("models", Tortoise.apps) - self.assertIsNotNone(Tortoise.get_connection("default")) - - async def test_db_url_init(self): - await Tortoise.init( - { - "connections": {"default": f"sqlite://{':memory:'}"}, - "apps": { - "models": {"models": ["tests.testmodels"], "default_connection": "default"} - }, - } - ) - self.assertIn("models", Tortoise.apps) - self.assertIsNotNone(Tortoise.get_connection("default")) - - async def test_shorthand_init(self): - await Tortoise.init( - db_url=f"sqlite://{':memory:'}", modules={"models": ["tests.testmodels"]} - ) - self.assertIn("models", Tortoise.apps) - self.assertIsNotNone(Tortoise.get_connection("default")) - - async def test_init_wrong_connection_engine(self): - with self.assertRaisesRegex(ImportError, "tortoise.backends.test"): - await Tortoise.init( - { - "connections": { - "default": { - "engine": "tortoise.backends.test", - "credentials": {"file_path": ":memory:"}, - } - }, - "apps": { - "models": {"models": ["tests.testmodels"], "default_connection": "default"} - }, - } - ) - - async def test_init_wrong_connection_engine_2(self): - with self.assertRaisesRegex( - ConfigurationError, - 'Backend for engine "tortoise.backends" does not implement db client', - ): - await Tortoise.init( - { - "connections": { - "default": { - "engine": "tortoise.backends", - "credentials": {"file_path": ":memory:"}, - } - }, - "apps": { - "models": {"models": ["tests.testmodels"], "default_connection": "default"} - }, - } - ) - - async def test_init_no_connections(self): - with self.assertRaisesRegex(ConfigurationError, 'Config must define "connections" section'): - await Tortoise.init( - { - "apps": { - "models": {"models": ["tests.testmodels"], "default_connection": "default"} - } - } - ) - - async def test_init_no_apps(self): - with self.assertRaisesRegex(ConfigurationError, 'Config must define "apps" section'): - await Tortoise.init( - { - "connections": { - "default": { - "engine": "tortoise.backends.sqlite", - "credentials": {"file_path": ":memory:"}, - } - } - } - ) - - async def test_init_config_and_config_file(self): - with self.assertRaisesRegex( - ConfigurationError, 'You should init either from "config", "config_file" or "db_url"' - ): - await Tortoise.init( - { - "connections": { - "default": { - "engine": "tortoise.backends.sqlite", - "credentials": {"file_path": ":memory:"}, - } - }, - "apps": { - "models": {"models": ["tests.testmodels"], "default_connection": "default"} - }, - }, - config_file="file.json", - ) - - async def test_init_config_file_wrong_extension(self): - with self.assertRaisesRegex( - ConfigurationError, "Unknown config extension .ini, only .yml and .json are supported" - ): - await Tortoise.init(config_file="config.ini") - - @test.skipIf(os.name == "nt", "path issue on Windows") - async def test_init_json_file(self): - await Tortoise.init(config_file=os.path.dirname(__file__) + "/init.json") - self.assertIn("models", Tortoise.apps) - self.assertIsNotNone(Tortoise.get_connection("default")) - - @test.skipIf(os.name == "nt", "path issue on Windows") - async def test_init_yaml_file(self): - await Tortoise.init(config_file=os.path.dirname(__file__) + "/init.yaml") - self.assertIn("models", Tortoise.apps) - self.assertIsNotNone(Tortoise.get_connection("default")) - - async def test_generate_schema_without_init(self): - with self.assertRaisesRegex( - ConfigurationError, r"You have to call \.init\(\) first before generating schemas" - ): - await Tortoise.generate_schemas() - - async def test_drop_databases_without_init(self): - with self.assertRaisesRegex( - ConfigurationError, r"You have to call \.init\(\) first before deleting schemas" - ): - await Tortoise._drop_databases() - - async def test_bad_models(self): - with self.assertRaisesRegex(ConfigurationError, 'Module "tests.testmodels2" not found'): - await Tortoise.init( - { - "connections": { - "default": { - "engine": "tortoise.backends.sqlite", - "credentials": {"file_path": ":memory:"}, - } - }, - "apps": { - "models": {"models": ["tests.testmodels2"], "default_connection": "default"} - }, - } - ) +# import os +# +# from tortoise import Tortoise +# from tortoise.contrib import test +# from tortoise.exceptions import ConfigurationError +# +# +# class TestInitErrors(test.SimpleTestCase): +# async def setUp(self): +# try: +# Tortoise.apps = {} +# Tortoise._connections = {} +# Tortoise._inited = False +# except ConfigurationError: +# pass +# Tortoise._inited = False +# +# async def tearDown(self): +# await Tortoise.close_connections() +# await Tortoise._reset_apps() +# +# async def test_basic_init(self): +# await Tortoise.init( +# { +# "connections": { +# "default": { +# "engine": "tortoise.backends.sqlite", +# "credentials": {"file_path": ":memory:"}, +# } +# }, +# "apps": { +# "models": {"models": ["tests.testmodels"], "default_connection": "default"} +# }, +# } +# ) +# self.assertIn("models", Tortoise.apps) +# self.assertIsNotNone(Tortoise.get_connection("default")) +# +# async def test_empty_modules_init(self): +# with self.assertWarnsRegex(RuntimeWarning, 'Module "tests.model_setup" has no models'): +# await Tortoise.init( +# { +# "connections": { +# "default": { +# "engine": "tortoise.backends.sqlite", +# "credentials": {"file_path": ":memory:"}, +# } +# }, +# "apps": { +# "models": {"models": ["tests.model_setup"], "default_connection": "default"} +# }, +# } +# ) +# +# async def test_dup1_init(self): +# with self.assertRaisesRegex( +# ConfigurationError, 'backward relation "events" duplicates in model Tournament' +# ): +# await Tortoise.init( +# { +# "connections": { +# "default": { +# "engine": "tortoise.backends.sqlite", +# "credentials": {"file_path": ":memory:"}, +# } +# }, +# "apps": { +# "models": { +# "models": ["tests.model_setup.models_dup1"], +# "default_connection": "default", +# } +# }, +# } +# ) +# +# async def test_dup2_init(self): +# with self.assertRaisesRegex( +# ConfigurationError, 'backward relation "events" duplicates in model Team' +# ): +# await Tortoise.init( +# { +# "connections": { +# "default": { +# "engine": "tortoise.backends.sqlite", +# "credentials": {"file_path": ":memory:"}, +# } +# }, +# "apps": { +# "models": { +# "models": ["tests.model_setup.models_dup2"], +# "default_connection": "default", +# } +# }, +# } +# ) +# +# async def test_dup3_init(self): +# with self.assertRaisesRegex( +# ConfigurationError, 'backward relation "event" duplicates in model Tournament' +# ): +# await Tortoise.init( +# { +# "connections": { +# "default": { +# "engine": "tortoise.backends.sqlite", +# "credentials": {"file_path": ":memory:"}, +# } +# }, +# "apps": { +# "models": { +# "models": ["tests.model_setup.models_dup3"], +# "default_connection": "default", +# } +# }, +# } +# ) +# +# async def test_generated_nonint(self): +# with self.assertRaisesRegex( +# ConfigurationError, "Field 'val' \\(CharField\\) can't be DB-generated" +# ): +# await Tortoise.init( +# { +# "connections": { +# "default": { +# "engine": "tortoise.backends.sqlite", +# "credentials": {"file_path": ":memory:"}, +# } +# }, +# "apps": { +# "models": { +# "models": ["tests.model_setup.model_generated_nonint"], +# "default_connection": "default", +# } +# }, +# } +# ) +# +# async def test_multiple_pk(self): +# with self.assertRaisesRegex( +# ConfigurationError, +# "Can't create model Tournament with two primary keys, only single primary key is supported", +# ): +# await Tortoise.init( +# { +# "connections": { +# "default": { +# "engine": "tortoise.backends.sqlite", +# "credentials": {"file_path": ":memory:"}, +# } +# }, +# "apps": { +# "models": { +# "models": ["tests.model_setup.model_multiple_pk"], +# "default_connection": "default", +# } +# }, +# } +# ) +# +# async def test_nonpk_id(self): +# with self.assertRaisesRegex( +# ConfigurationError, +# "Can't create model Tournament without explicit primary key if" +# " field 'id' already present", +# ): +# await Tortoise.init( +# { +# "connections": { +# "default": { +# "engine": "tortoise.backends.sqlite", +# "credentials": {"file_path": ":memory:"}, +# } +# }, +# "apps": { +# "models": { +# "models": ["tests.model_setup.model_nonpk_id"], +# "default_connection": "default", +# } +# }, +# } +# ) +# +# async def test_unknown_connection(self): +# with self.assertRaisesRegex(ConfigurationError, 'Unknown connection "fioop"'): +# await Tortoise.init( +# { +# "connections": { +# "default": { +# "engine": "tortoise.backends.sqlite", +# "credentials": {"file_path": ":memory:"}, +# } +# }, +# "apps": { +# "models": {"models": ["tests.testmodels"], "default_connection": "fioop"} +# }, +# } +# ) +# +# async def test_url_without_modules(self): +# with self.assertRaisesRegex( +# ConfigurationError, 'You must specify "db_url" and "modules" together' +# ): +# await Tortoise.init(db_url=f"sqlite://{':memory:'}") +# +# async def test_default_connection_init(self): +# await Tortoise.init( +# { +# "connections": { +# "default": { +# "engine": "tortoise.backends.sqlite", +# "credentials": {"file_path": ":memory:"}, +# } +# }, +# "apps": {"models": {"models": ["tests.testmodels"]}}, +# } +# ) +# self.assertIn("models", Tortoise.apps) +# self.assertIsNotNone(Tortoise.get_connection("default")) +# +# async def test_db_url_init(self): +# await Tortoise.init( +# { +# "connections": {"default": f"sqlite://{':memory:'}"}, +# "apps": { +# "models": {"models": ["tests.testmodels"], "default_connection": "default"} +# }, +# } +# ) +# self.assertIn("models", Tortoise.apps) +# self.assertIsNotNone(Tortoise.get_connection("default")) +# +# async def test_shorthand_init(self): +# await Tortoise.init( +# db_url=f"sqlite://{':memory:'}", modules={"models": ["tests.testmodels"]} +# ) +# self.assertIn("models", Tortoise.apps) +# self.assertIsNotNone(Tortoise.get_connection("default")) +# +# async def test_init_wrong_connection_engine(self): +# with self.assertRaisesRegex(ImportError, "tortoise.backends.test"): +# await Tortoise.init( +# { +# "connections": { +# "default": { +# "engine": "tortoise.backends.test", +# "credentials": {"file_path": ":memory:"}, +# } +# }, +# "apps": { +# "models": {"models": ["tests.testmodels"], "default_connection": "default"} +# }, +# } +# ) +# +# async def test_init_wrong_connection_engine_2(self): +# with self.assertRaisesRegex( +# ConfigurationError, +# 'Backend for engine "tortoise.backends" does not implement db client', +# ): +# await Tortoise.init( +# { +# "connections": { +# "default": { +# "engine": "tortoise.backends", +# "credentials": {"file_path": ":memory:"}, +# } +# }, +# "apps": { +# "models": {"models": ["tests.testmodels"], "default_connection": "default"} +# }, +# } +# ) +# +# async def test_init_no_connections(self): +# with self.assertRaisesRegex(ConfigurationError, 'Config must define "connections" section'): +# await Tortoise.init( +# { +# "apps": { +# "models": {"models": ["tests.testmodels"], "default_connection": "default"} +# } +# } +# ) +# +# async def test_init_no_apps(self): +# with self.assertRaisesRegex(ConfigurationError, 'Config must define "apps" section'): +# await Tortoise.init( +# { +# "connections": { +# "default": { +# "engine": "tortoise.backends.sqlite", +# "credentials": {"file_path": ":memory:"}, +# } +# } +# } +# ) +# +# async def test_init_config_and_config_file(self): +# with self.assertRaisesRegex( +# ConfigurationError, 'You should init either from "config", "config_file" or "db_url"' +# ): +# await Tortoise.init( +# { +# "connections": { +# "default": { +# "engine": "tortoise.backends.sqlite", +# "credentials": {"file_path": ":memory:"}, +# } +# }, +# "apps": { +# "models": {"models": ["tests.testmodels"], "default_connection": "default"} +# }, +# }, +# config_file="file.json", +# ) +# +# async def test_init_config_file_wrong_extension(self): +# with self.assertRaisesRegex( +# ConfigurationError, "Unknown config extension .ini, only .yml and .json are supported" +# ): +# await Tortoise.init(config_file="config.ini") +# +# @test.skipIf(os.name == "nt", "path issue on Windows") +# async def test_init_json_file(self): +# await Tortoise.init(config_file=os.path.dirname(__file__) + "/init.json") +# self.assertIn("models", Tortoise.apps) +# self.assertIsNotNone(Tortoise.get_connection("default")) +# +# @test.skipIf(os.name == "nt", "path issue on Windows") +# async def test_init_yaml_file(self): +# await Tortoise.init(config_file=os.path.dirname(__file__) + "/init.yaml") +# self.assertIn("models", Tortoise.apps) +# self.assertIsNotNone(Tortoise.get_connection("default")) +# +# async def test_generate_schema_without_init(self): +# with self.assertRaisesRegex( +# ConfigurationError, r"You have to call \.init\(\) first before generating schemas" +# ): +# await Tortoise.generate_schemas() +# +# async def test_drop_databases_without_init(self): +# with self.assertRaisesRegex( +# ConfigurationError, r"You have to call \.init\(\) first before deleting schemas" +# ): +# await Tortoise._drop_databases() +# +# async def test_bad_models(self): +# with self.assertRaisesRegex(ConfigurationError, 'Module "tests.testmodels2" not found'): +# await Tortoise.init( +# { +# "connections": { +# "default": { +# "engine": "tortoise.backends.sqlite", +# "credentials": {"file_path": ":memory:"}, +# } +# }, +# "apps": { +# "models": {"models": ["tests.testmodels2"], "default_connection": "default"} +# }, +# } +# ) diff --git a/tests/schema/test_generate_schema.py b/tests/schema/test_generate_schema.py index 17aafc939..dd0eec530 100644 --- a/tests/schema/test_generate_schema.py +++ b/tests/schema/test_generate_schema.py @@ -2,7 +2,7 @@ import re from unittest.mock import AsyncMock, patch -from tortoise import Tortoise +from tortoise import Tortoise, connections from tortoise.contrib import test from tortoise.exceptions import ConfigurationError from tortoise.utils import get_schema_sql @@ -12,7 +12,6 @@ class TestGenerateSchema(test.SimpleTestCase): async def setUp(self): try: Tortoise.apps = {} - Tortoise._connections = {} Tortoise._inited = False except ConfigurationError: pass @@ -24,7 +23,6 @@ async def setUp(self): ] async def tearDown(self): - Tortoise._connections = {} await Tortoise._reset_apps() async def init_for(self, module: str, safe=False) -> None: @@ -42,7 +40,7 @@ async def init_for(self, module: str, safe=False) -> None: "apps": {"models": {"models": [module], "default_connection": "default"}}, } ) - self.sqls = get_schema_sql(Tortoise._connections["default"], safe).split(";\n") + self.sqls = get_schema_sql(connections.get("default"), safe).split(";\n") def get_sql(self, text: str) -> str: return re.sub(r"[ \t\n\r]+", " ", " ".join([sql for sql in self.sqls if text in sql])) @@ -415,7 +413,7 @@ async def init_for(self, module: str, safe=False) -> None: "apps": {"models": {"models": [module], "default_connection": "default"}}, } ) - self.sqls = get_schema_sql(Tortoise._connections["default"], safe).split("; ") + self.sqls = get_schema_sql(connections.get("default"), safe).split("; ") except ImportError: raise test.SkipTest("aiomysql not installed") @@ -498,7 +496,7 @@ async def test_schema_no_db_constraint(self): async def test_schema(self): self.maxDiff = None await self.init_for("tests.schema.models_schema_create") - sql = get_schema_sql(Tortoise.get_connection("default"), safe=False) + sql = get_schema_sql(connections.get("default"), safe=False) self.assertEqual( sql.strip(), """ @@ -793,7 +791,7 @@ async def init_for(self, module: str, safe=False) -> None: "apps": {"models": {"models": [module], "default_connection": "default"}}, } ) - self.sqls = get_schema_sql(Tortoise._connections["default"], safe).split("; ") + self.sqls = get_schema_sql(connections.get("default"), safe).split("; ") except ImportError: raise test.SkipTest("asyncpg not installed") diff --git a/tests/test_two_databases.py b/tests/test_two_databases.py index 77b41bf5e..25d89378d 100644 --- a/tests/test_two_databases.py +++ b/tests/test_two_databases.py @@ -1,5 +1,5 @@ from tests.testmodels import Event, EventTwo, TeamTwo, Tournament -from tortoise import Tortoise +from tortoise import Tortoise, connections from tortoise.contrib import test from tortoise.exceptions import OperationalError, ParamsError from tortoise.transactions import in_transaction @@ -17,8 +17,8 @@ async def setUp(self): } await Tortoise.init(merged_config, _create_db=True) await Tortoise.generate_schemas() - self.db = Tortoise.get_connection("models") - self.second_db = Tortoise.get_connection("events") + self.db = connections.get("models") + self.second_db = connections.get("events") async def tearDown(self): await Tortoise._drop_databases() diff --git a/tests/utils/test_run_async.py b/tests/utils/test_run_async.py index b74341be8..58a0fa213 100644 --- a/tests/utils/test_run_async.py +++ b/tests/utils/test_run_async.py @@ -1,7 +1,7 @@ import os from unittest import TestCase, skipIf -from tortoise import Tortoise, run_async +from tortoise import Tortoise, run_async, connections @skipIf(os.name == "nt", "stuck with Windows") @@ -12,25 +12,25 @@ def setUp(self): async def init(self): await Tortoise.init(db_url="sqlite://:memory:", modules={"models": []}) self.somevalue = 2 - self.assertNotEqual(Tortoise._connections, {}) + self.assertNotEqual(connections._get_storage(), {}) async def init_raise(self): await Tortoise.init(db_url="sqlite://:memory:", modules={"models": []}) self.somevalue = 3 - self.assertNotEqual(Tortoise._connections, {}) + self.assertNotEqual(connections._get_storage(), {}) raise Exception("Some exception") def test_run_async(self): - self.assertEqual(Tortoise._connections, {}) + self.assertEqual(connections._get_storage(), {}) self.assertEqual(self.somevalue, 1) run_async(self.init()) - self.assertEqual(Tortoise._connections, {}) + self.assertEqual(connections._get_storage(), {}) self.assertEqual(self.somevalue, 2) def test_run_async_raised(self): - self.assertEqual(Tortoise._connections, {}) + self.assertEqual(connections._get_storage(), {}) self.assertEqual(self.somevalue, 1) with self.assertRaises(Exception): run_async(self.init_raise()) - self.assertEqual(Tortoise._connections, {}) + self.assertEqual(connections._get_storage(), {}) self.assertEqual(self.somevalue, 3) diff --git a/tortoise/__init__.py b/tortoise/__init__.py index a1c028202..e32d82660 100644 --- a/tortoise/__init__.py +++ b/tortoise/__init__.py @@ -7,12 +7,15 @@ from copy import deepcopy from inspect import isclass from types import ModuleType -from typing import Coroutine, Dict, Iterable, List, Optional, Tuple, Type, Union, cast +from typing import ( + Coroutine, Dict, Iterable, List, Optional, Tuple, Type, Union, cast +) from pypika import Table from tortoise.backends.base.client import BaseDBAsyncClient from tortoise.backends.base.config_generator import expand_db_url, generate_config +from tortoise.connection import connections from tortoise.exceptions import ConfigurationError from tortoise.fields.relational import ( BackwardFKRelation, @@ -40,7 +43,7 @@ def get_connection(cls, connection_name: str) -> BaseDBAsyncClient: :raises KeyError: If connection name does not exist. """ - return cls._connections[connection_name] + return connections.get(connection_name) @classmethod def describe_model( @@ -542,7 +545,7 @@ async def init( :raises ConfigurationError: For any configuration error """ if cls._inited: - await cls.close_connections() + await connections.close_all(discard=True) await cls._reset_apps() if int(bool(config) + bool(config_file) + bool(db_url)) != 1: raise ConfigurationError( @@ -595,7 +598,7 @@ async def init( ) cls._init_timezone(use_tz, timezone) - await cls._init_connections(connections_config, _create_db) + await connections._init(connections_config, _create_db) cls._init_apps(apps_config) cls._init_routers(routers) @@ -643,7 +646,6 @@ async def _reset_apps(cls) -> None: if isinstance(model, ModelMeta): model._meta.default_connection = None cls.apps.clear() - current_transaction_map.clear() @classmethod async def generate_schemas(cls, safe: bool = True) -> None: @@ -658,7 +660,7 @@ async def generate_schemas(cls, safe: bool = True) -> None: """ if not cls._inited: raise ConfigurationError("You have to call .init() first before generating schemas") - for connection in cls._connections.values(): + for connection in connections.all(): await generate_schema_for_client(connection, safe) @classmethod @@ -671,10 +673,13 @@ async def _drop_databases(cls) -> None: """ if not cls._inited: raise ConfigurationError("You have to call .init() first before deleting schemas") - for connection in cls._connections.values(): - await connection.close() - await connection.db_delete() - cls._connections = {} + # this closes any existing connections/pool if any and clears + # the storage + await connections.close_all() + for conn in connections.all(): + await conn.db_delete() + connections.discard(conn.connection_name) + await cls._reset_apps() @classmethod @@ -706,7 +711,7 @@ async def do_stuff(): try: loop.run_until_complete(coro) finally: - loop.run_until_complete(Tortoise.close_connections()) + loop.run_until_complete(connections.close_all(discard=True)) __version__ = "0.18.0" diff --git a/tortoise/backends/asyncpg/client.py b/tortoise/backends/asyncpg/client.py index 7d9648c52..a6e2382f4 100644 --- a/tortoise/backends/asyncpg/client.py +++ b/tortoise/backends/asyncpg/client.py @@ -137,7 +137,7 @@ async def db_delete(self) -> None: await self.close() def acquire_connection(self) -> Union["PoolConnectionWrapper", "ConnectionWrapper"]: - return PoolConnectionWrapper(self._pool) + return PoolConnectionWrapper(self) def _in_transaction(self) -> "TransactionContext": return TransactionContextPooled(TransactionWrapper(self)) @@ -209,13 +209,13 @@ def __init__(self, connection: AsyncpgDBClient) -> None: self.connection_name = connection.connection_name self.transaction: Transaction = None self._finalized = False - self._parent = connection + self._parent: AsyncpgDBClient = connection def _in_transaction(self) -> "TransactionContext": return NestedTransactionPooledContext(self) def acquire_connection(self) -> "ConnectionWrapper": - return ConnectionWrapper(self._connection, self._lock) + return ConnectionWrapper(self._lock, self) @translate_exceptions async def execute_many(self, query: str, values: list) -> None: diff --git a/tortoise/backends/base/client.py b/tortoise/backends/base/client.py index 2f6248c9c..d35821ba6 100644 --- a/tortoise/backends/base/client.py +++ b/tortoise/backends/base/client.py @@ -3,11 +3,11 @@ from pypika import Query +from tortoise.connection import connections from tortoise.backends.base.executor import BaseExecutor from tortoise.backends.base.schema_generator import BaseSchemaGenerator from tortoise.exceptions import TransactionManagementError from tortoise.log import db_client_logger -from tortoise.transactions import current_transaction_map class Capabilities: @@ -203,13 +203,20 @@ async def execute_query_dict(self, query: str, values: Optional[list] = None) -> class ConnectionWrapper: - __slots__ = ("connection", "lock") + __slots__ = ("connection", "lock", "client") - def __init__(self, connection: Any, lock: asyncio.Lock) -> None: - self.connection = connection + def __init__(self, lock: asyncio.Lock, client: Any) -> None: self.lock: asyncio.Lock = lock + self.client = client + self.connection: Any = client._connection + + async def ensure_connection(self): + if not self.connection: + await self.client.create_connection(with_db=True) + self.connection = self.client._connection async def __aenter__(self): + await self.ensure_connection() await self.lock.acquire() return self.connection @@ -225,10 +232,15 @@ def __init__(self, connection: Any) -> None: self.connection_name = connection.connection_name self.lock = getattr(connection, "_trxlock", None) + async def ensure_connection(self): + if not self.connection._connection: + await self.connection._parent.create_connection(with_db=True) + self.connection._connection = self.connection._parent._connection + async def __aenter__(self): + await self.ensure_connection() await self.lock.acquire() - current_transaction = current_transaction_map[self.connection_name] - self.token = current_transaction.set(self.connection) + self.token = connections.set(self.connection_name, self.connection) await self.connection.start() return self.connection @@ -240,16 +252,20 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: await self.connection.rollback() else: await self.connection.commit() - current_transaction_map[self.connection_name].reset(self.token) + connections.reset(self.token) self.lock.release() class TransactionContextPooled(TransactionContext): - __slots__ = ("connection", "connection_name", "token") + __slots__ = ("conn_wrapper", "connection", "connection_name", "token") + + async def ensure_connection(self): + if not self.connection._parent._pool: + await self.connection._parent.create_connection(with_db=True) async def __aenter__(self): - current_transaction = current_transaction_map[self.connection_name] - self.token = current_transaction.set(self.connection) + await self.ensure_connection() + self.token = connections.set(self.connection_name, self.connection) self.connection._connection = await self.connection._parent._pool.acquire() await self.connection.start() return self.connection @@ -262,7 +278,6 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: await self.connection.rollback() else: await self.connection.commit() - current_transaction_map[self.connection_name].reset(self.token) if self.connection._parent._pool: await self.connection._parent._pool.release(self.connection._connection) @@ -294,11 +309,18 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: class PoolConnectionWrapper: - def __init__(self, pool: Any) -> None: - self.pool = pool + def __init__(self, client: Any) -> None: + self.pool = client._pool + self.client = client self.connection = None + async def ensure_connection(self): + if not self.pool: + await self.client.create_connection(with_db=True) + self.pool = self.client._pool + async def __aenter__(self): + await self.ensure_connection() # get first available connection self.connection = await self.pool.acquire() return self.connection diff --git a/tortoise/backends/base/schema_generator.py b/tortoise/backends/base/schema_generator.py index ea493f7e9..53fac1368 100644 --- a/tortoise/backends/base/schema_generator.py +++ b/tortoise/backends/base/schema_generator.py @@ -444,5 +444,4 @@ def get_create_schema_sql(self, safe: bool = True) -> str: return schema_creation_string async def generate_from_string(self, creation_string: str) -> None: - # print(creation_string) await self.client.execute_script(creation_string) diff --git a/tortoise/backends/mysql/client.py b/tortoise/backends/mysql/client.py index 675f27dc5..cdf54edc3 100644 --- a/tortoise/backends/mysql/client.py +++ b/tortoise/backends/mysql/client.py @@ -159,7 +159,7 @@ async def db_delete(self) -> None: await self.close() def acquire_connection(self) -> Union["ConnectionWrapper", "PoolConnectionWrapper"]: - return PoolConnectionWrapper(self._pool) + return PoolConnectionWrapper(self) def _in_transaction(self) -> "TransactionContext": return TransactionContextPooled(TransactionWrapper(self)) @@ -229,7 +229,7 @@ def _in_transaction(self) -> "TransactionContext": return NestedTransactionPooledContext(self) def acquire_connection(self) -> ConnectionWrapper: - return ConnectionWrapper(self._connection, self._lock) + return ConnectionWrapper(self._lock, self) @translate_exceptions async def execute_many(self, query: str, values: list) -> None: diff --git a/tortoise/backends/sqlite/client.py b/tortoise/backends/sqlite/client.py index 5f631e029..fe2071d25 100644 --- a/tortoise/backends/sqlite/client.py +++ b/tortoise/backends/sqlite/client.py @@ -101,7 +101,7 @@ async def db_delete(self) -> None: raise e def acquire_connection(self) -> ConnectionWrapper: - return ConnectionWrapper(self._connection, self._lock) + return ConnectionWrapper(self._lock, self) def _in_transaction(self) -> "TransactionContext": return TransactionContext(TransactionWrapper(self)) @@ -160,6 +160,7 @@ def __init__(self, connection: SqliteClient) -> None: self.log = connection.log self._finalized = False self.fetch_inserted = connection.fetch_inserted + self._parent = connection def _in_transaction(self) -> "TransactionContext": return NestedTransactionContext(self) diff --git a/tortoise/connection.py b/tortoise/connection.py new file mode 100644 index 000000000..173ae0a4d --- /dev/null +++ b/tortoise/connection.py @@ -0,0 +1,123 @@ +import asyncio +import contextvars +import copy +import importlib +from contextvars import ContextVar +from typing import Dict, Optional, Any, Type, Union, List, TYPE_CHECKING + +from tortoise.backends.base.config_generator import expand_db_url +from tortoise.exceptions import ConfigurationError + + +if TYPE_CHECKING: + from tortoise.backends.base.client import BaseDBAsyncClient + + +class ConnectionHandler: + _conn_storage: ContextVar[Dict[str, 'BaseDBAsyncClient']] = contextvars.ContextVar( + "_conn_storage", default={} + ) + + def __init__(self): + self._db_config: Optional[Dict[str, Any]] = None + self._create_db: bool = False + + async def _init(self, db_config: Dict[str, Any], create_db: bool): + self._db_config = db_config + self._create_db = create_db + await self._init_connections() + + @property + def db_config(self) -> Dict[str, Any]: + return self._db_config + + def _get_storage(self) -> Dict[str, 'BaseDBAsyncClient']: + return self._conn_storage.get() + + def _copy_storage(self): + return copy.copy(self._get_storage()) + + def _clear_storage(self) -> None: + self._get_storage().clear() + + def _discover_client_class(self, engine: str) -> Type['BaseDBAsyncClient']: + # Let exception bubble up for transparency + engine_module = importlib.import_module(engine) + + try: + client_class = engine_module.client_class # type: ignore + except AttributeError: + raise ConfigurationError(f'Backend for engine "{engine}" does not implement db client') + return client_class + + def _get_db_info(self, conn_alias: str) -> Union[str, Dict]: + try: + return self._db_config[conn_alias] + except KeyError: + raise ConfigurationError( + f'Unable to get db settings for alias {conn_alias}. Please ' + f'check if the config dict contains this alias and try again' + ) + + async def _init_connections(self) -> None: + # print("storage within init is ", self._get_storage()) + for alias in self.db_config: + connection: 'BaseDBAsyncClient' = self.get(alias) + if self._create_db: + await connection.db_create() + + def _create_connection(self, conn_alias: str) -> 'BaseDBAsyncClient': + db_info: Union[str, Dict] = self._get_db_info(conn_alias) + if isinstance(db_info, str): + db_info = expand_db_url(db_info) + client_class: Type['BaseDBAsyncClient'] = self._discover_client_class( + db_info.get("engine") + ) + db_params = db_info["credentials"].copy() + db_params.update({"connection_name": conn_alias}) + connection: 'BaseDBAsyncClient' = client_class(**db_params) + return connection + + def get(self, conn_alias: str) -> 'BaseDBAsyncClient': + storage: Dict[str, 'BaseDBAsyncClient'] = self._get_storage() + try: + return storage[conn_alias] + except KeyError: + connection: BaseDBAsyncClient = self._create_connection(conn_alias) + storage[conn_alias] = connection + return connection + + def _set_storage(self, new_storage: Dict[str, 'BaseDBAsyncClient']) -> contextvars.Token: + # Should be used only for testing purposes. + return self._conn_storage.set(new_storage) + + def set(self, conn_alias: str, conn) -> contextvars.Token: + storage_copy = self._copy_storage() + storage_copy[conn_alias] = conn + return self._conn_storage.set(storage_copy) + + def discard(self, conn_alias: str) -> Optional['BaseDBAsyncClient']: + return self._get_storage().pop(conn_alias, None) + + def reset(self, token: contextvars.Token): + current_storage = self._get_storage() + self._conn_storage.reset(token) + prev_storage = self._get_storage() + for alias, conn in current_storage.items(): + if alias not in prev_storage: + prev_storage[alias] = conn + + def all(self) -> List['BaseDBAsyncClient']: + # Returning a list here so as to avoid accidental + # mutation of the underlying storage dict + return list(self._get_storage().values()) + + async def close_all(self, discard: bool = False) -> None: + tasks = [conn.close() for conn in self._get_storage().values()] + await asyncio.gather(*tasks) + if discard: + for alias in tuple(self._get_storage()): + self.discard(alias) + + +connections = ConnectionHandler() \ No newline at end of file diff --git a/tortoise/contrib/test/__init__.py b/tortoise/contrib/test/__init__.py index 6e7010afc..38759491f 100644 --- a/tortoise/contrib/test/__init__.py +++ b/tortoise/contrib/test/__init__.py @@ -1,4 +1,5 @@ import asyncio +import contextvars import os as _os import unittest from asyncio.events import AbstractEventLoop @@ -8,10 +9,9 @@ from unittest import SkipTest, expectedFailure, skip, skipIf, skipUnless from unittest.result import TestResult -from tortoise import Model, Tortoise +from tortoise import Model, Tortoise, connections from tortoise.backends.base.config_generator import generate_config as _generate_config from tortoise.exceptions import DBConnectionError -from tortoise.transactions import current_transaction_map __all__ = ( "SimpleTestCase", @@ -43,7 +43,7 @@ _SELECTOR = None _LOOP: AbstractEventLoop = None # type: ignore _MODULES: Iterable[Union[str, ModuleType]] = [] -_CONN_MAP: dict = {} +_CONN_CONFIG: dict = {} def getDBConfig(app_label: str, modules: Iterable[Union[str, ModuleType]]) -> dict: @@ -74,8 +74,8 @@ async def _init_db(config: dict) -> None: def _restore_default() -> None: Tortoise.apps = {} - Tortoise._connections = _CONNECTIONS.copy() - current_transaction_map.update(_CONN_MAP) + connections._get_storage().update(_CONNECTIONS.copy()) + connections._db_config = _CONN_CONFIG.copy() Tortoise._init_apps(_CONFIG["apps"]) Tortoise._inited = True @@ -101,20 +101,20 @@ def initializer( global _LOOP global _TORTOISE_TEST_DB global _MODULES - global _CONN_MAP + global _CONN_CONFIG _MODULES = modules if db_url is not None: # pragma: nobranch _TORTOISE_TEST_DB = db_url _CONFIG = getDBConfig(app_label=app_label, modules=_MODULES) - loop = loop or asyncio.get_event_loop() _LOOP = loop _SELECTOR = loop._selector # type: ignore loop.run_until_complete(_init_db(_CONFIG)) - _CONNECTIONS = Tortoise._connections.copy() - _CONN_MAP = current_transaction_map.copy() + _CONNECTIONS = connections._copy_storage() + _CONN_CONFIG = connections.db_config.copy() + connections._clear_storage() + connections.db_config.clear() Tortoise.apps = {} - Tortoise._connections = {} Tortoise._inited = False @@ -175,12 +175,17 @@ def _init_loop(self) -> None: self.loop = asyncio.new_event_loop() async def _setUpDB(self) -> None: - pass + # setting storage to an empty dict explicitly to create a + # ContextVar specific scope for each test case. The storage will + # either be restored from _restore_default or re-initialised depending on + # how the test case should behave + self.token = connections._set_storage({}) async def _tearDownDB(self) -> None: pass async def _setUp(self) -> None: + self.token: Optional[contextvars.Token] = None await self._setUpDB() if asyncio.iscoroutinefunction(self.setUp): @@ -191,14 +196,24 @@ async def _setUp(self) -> None: # don't take into account if the loop ran during setUp self.loop._asynctest_ran = False # type: ignore + def _reset_conn_state(self): + # clearing the storage and restoring to previous storage state + # of the contextvar + connections._clear_storage() + connections._db_config.clear() + if self.token: + connections.reset(self.token) + async def _tearDown(self) -> None: if asyncio.iscoroutinefunction(self.tearDown): await self.asyncTearDown() else: self.tearDown() await self._tearDownDB() + + self._reset_conn_state() + Tortoise.apps = {} - Tortoise._connections = {} Tortoise._inited = False # Override unittest.TestCase methods which call setUp() and tearDown() @@ -301,13 +316,12 @@ class IsolatedTestCase(SimpleTestCase): tortoise_test_modules: Iterable[Union[str, ModuleType]] = [] async def _setUpDB(self) -> None: + await super()._setUpDB() config = getDBConfig(app_label="models", modules=self.tortoise_test_modules or _MODULES) await Tortoise.init(config, _create_db=True) await Tortoise.generate_schemas(safe=False) - self._connections = Tortoise._connections.copy() async def _tearDownDB(self) -> None: - Tortoise._connections = self._connections.copy() await Tortoise._drop_databases() @@ -336,25 +350,39 @@ async def _tearDownDB(self) -> None: class TransactionTestContext: - __slots__ = ("connection", "connection_name", "token") + __slots__ = ("connection", "connection_name", "token", "uses_pool") def __init__(self, connection) -> None: self.connection = connection self.connection_name = connection.connection_name + self.uses_pool = hasattr(self.connection._parent, "_pool") - async def __aenter__(self): - current_transaction = current_transaction_map[self.connection_name] - self.token = current_transaction.set(self.connection) - if hasattr(self.connection, "_parent"): + async def ensure_connection(self): + is_conn_established = self.connection._connection != None + if self.uses_pool: + is_conn_established = self.connection._parent._pool != None + + # If the underlying pool/connection hasn't been established then + # first create the pool/connection + if not is_conn_established: + await self.connection._parent.create_connection(with_db=True) + + if self.uses_pool: self.connection._connection = await self.connection._parent._pool.acquire() + else: + self.connection._connection = self.connection._parent._connection + + async def __aenter__(self): + await self.ensure_connection() + self.token = connections.set(self.connection_name, self.connection) await self.connection.start() return self.connection async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: await self.connection.rollback() - if hasattr(self.connection, "_parent"): + if self.uses_pool: await self.connection._parent._pool.release(self.connection._connection) - current_transaction_map[self.connection_name].reset(self.token) + connections.reset(self.token) class TestCase(TruncationTestCase): @@ -375,6 +403,13 @@ async def _run_outcome(self, outcome, expecting_failure, testMethod) -> None: else: await super()._run_outcome(outcome, expecting_failure, testMethod) + # Clearing out the storage so as to provide a clean storage state + # for all tests. Have to explicitly clear it here since teardown + # gets called as part of _run_outcome which would only clear the nested + # storage in the contextvar. This is not needed for other testcase classes + # as they don't run inside a transaction + self._reset_conn_state() + async def _setUpDB(self) -> None: pass diff --git a/tortoise/models.py b/tortoise/models.py index 76d1bc625..9981d8146 100644 --- a/tortoise/models.py +++ b/tortoise/models.py @@ -22,6 +22,7 @@ from pypika import Order, Query, Table from pypika.terms import Term +from tortoise import connections from tortoise.backends.base.client import BaseDBAsyncClient from tortoise.exceptions import ( ConfigurationError, @@ -270,10 +271,7 @@ def add_field(self, name: str, value: Field) -> None: @property def db(self) -> BaseDBAsyncClient: - try: - return current_transaction_map[self.default_connection].get() - except KeyError: - raise ConfigurationError("No DB associated to model") + return connections.get(self.default_connection) @property def ordering(self) -> Tuple[Tuple[str, Order], ...]: diff --git a/tortoise/router.py b/tortoise/router.py index 885398d70..f2e14d3a9 100644 --- a/tortoise/router.py +++ b/tortoise/router.py @@ -1,7 +1,7 @@ from typing import TYPE_CHECKING, List, Optional, Type -from tortoise.exceptions import ParamsError -from tortoise.transactions import get_connection +from tortoise.exceptions import ConfigurationError +from tortoise.connection import connections if TYPE_CHECKING: from tortoise import BaseDBAsyncClient, Model @@ -28,8 +28,8 @@ def _router_func(self, model: Type["Model"], action: str): def _db_route(self, model: Type["Model"], action: str): try: - return get_connection(self._router_func(model, action)) - except ParamsError: + return connections.get(self._router_func(model, action)) + except ConfigurationError: return None def db_for_read(self, model: Type["Model"]) -> Optional["BaseDBAsyncClient"]: diff --git a/tortoise/transactions.py b/tortoise/transactions.py index 02c41f122..dd86c28d1 100644 --- a/tortoise/transactions.py +++ b/tortoise/transactions.py @@ -1,6 +1,7 @@ from functools import wraps from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, cast +from tortoise import connections from tortoise.exceptions import ParamsError current_transaction_map: dict = {} @@ -12,18 +13,17 @@ F = TypeVar("F", bound=FuncType) -def get_connection(connection_name: Optional[str]) -> "BaseDBAsyncClient": - from tortoise import Tortoise +def _get_connection(connection_name: Optional[str]) -> "BaseDBAsyncClient": if connection_name: - connection = current_transaction_map[connection_name].get() - elif len(Tortoise._connections) == 1: - connection_name = list(Tortoise._connections.keys())[0] - connection = current_transaction_map[connection_name].get() + connection = connections.get(connection_name) + elif len(connections.db_config) == 1: + connection_name = next(iter(connections.db_config.keys())) + connection = connections.get(connection_name) else: raise ParamsError( "You are running with multiple databases, so you should specify" - f" connection_name: {list(Tortoise._connections.keys())}" + f" connection_name: {list(connections.db_config)}" ) return connection @@ -38,7 +38,7 @@ def in_transaction(connection_name: Optional[str] = None) -> "TransactionContext :param connection_name: name of connection to run with, optional if you have only one db connection """ - connection = get_connection(connection_name) + connection = _get_connection(connection_name) return connection._in_transaction() From c5c7d31dd18af9caf75cc66d30be47a68d98716e Mon Sep 17 00:00:00 2001 From: Aditya N Date: Sun, 5 Dec 2021 17:14:15 +0530 Subject: [PATCH 02/25] - Refactored test_init.py to conform to the new connection management scheme --- tests/model_setup/test_init.py | 720 ++++++++++++++++----------------- 1 file changed, 359 insertions(+), 361 deletions(-) diff --git a/tests/model_setup/test_init.py b/tests/model_setup/test_init.py index fb4393a1a..d9bcdf66e 100644 --- a/tests/model_setup/test_init.py +++ b/tests/model_setup/test_init.py @@ -1,361 +1,359 @@ -# import os -# -# from tortoise import Tortoise -# from tortoise.contrib import test -# from tortoise.exceptions import ConfigurationError -# -# -# class TestInitErrors(test.SimpleTestCase): -# async def setUp(self): -# try: -# Tortoise.apps = {} -# Tortoise._connections = {} -# Tortoise._inited = False -# except ConfigurationError: -# pass -# Tortoise._inited = False -# -# async def tearDown(self): -# await Tortoise.close_connections() -# await Tortoise._reset_apps() -# -# async def test_basic_init(self): -# await Tortoise.init( -# { -# "connections": { -# "default": { -# "engine": "tortoise.backends.sqlite", -# "credentials": {"file_path": ":memory:"}, -# } -# }, -# "apps": { -# "models": {"models": ["tests.testmodels"], "default_connection": "default"} -# }, -# } -# ) -# self.assertIn("models", Tortoise.apps) -# self.assertIsNotNone(Tortoise.get_connection("default")) -# -# async def test_empty_modules_init(self): -# with self.assertWarnsRegex(RuntimeWarning, 'Module "tests.model_setup" has no models'): -# await Tortoise.init( -# { -# "connections": { -# "default": { -# "engine": "tortoise.backends.sqlite", -# "credentials": {"file_path": ":memory:"}, -# } -# }, -# "apps": { -# "models": {"models": ["tests.model_setup"], "default_connection": "default"} -# }, -# } -# ) -# -# async def test_dup1_init(self): -# with self.assertRaisesRegex( -# ConfigurationError, 'backward relation "events" duplicates in model Tournament' -# ): -# await Tortoise.init( -# { -# "connections": { -# "default": { -# "engine": "tortoise.backends.sqlite", -# "credentials": {"file_path": ":memory:"}, -# } -# }, -# "apps": { -# "models": { -# "models": ["tests.model_setup.models_dup1"], -# "default_connection": "default", -# } -# }, -# } -# ) -# -# async def test_dup2_init(self): -# with self.assertRaisesRegex( -# ConfigurationError, 'backward relation "events" duplicates in model Team' -# ): -# await Tortoise.init( -# { -# "connections": { -# "default": { -# "engine": "tortoise.backends.sqlite", -# "credentials": {"file_path": ":memory:"}, -# } -# }, -# "apps": { -# "models": { -# "models": ["tests.model_setup.models_dup2"], -# "default_connection": "default", -# } -# }, -# } -# ) -# -# async def test_dup3_init(self): -# with self.assertRaisesRegex( -# ConfigurationError, 'backward relation "event" duplicates in model Tournament' -# ): -# await Tortoise.init( -# { -# "connections": { -# "default": { -# "engine": "tortoise.backends.sqlite", -# "credentials": {"file_path": ":memory:"}, -# } -# }, -# "apps": { -# "models": { -# "models": ["tests.model_setup.models_dup3"], -# "default_connection": "default", -# } -# }, -# } -# ) -# -# async def test_generated_nonint(self): -# with self.assertRaisesRegex( -# ConfigurationError, "Field 'val' \\(CharField\\) can't be DB-generated" -# ): -# await Tortoise.init( -# { -# "connections": { -# "default": { -# "engine": "tortoise.backends.sqlite", -# "credentials": {"file_path": ":memory:"}, -# } -# }, -# "apps": { -# "models": { -# "models": ["tests.model_setup.model_generated_nonint"], -# "default_connection": "default", -# } -# }, -# } -# ) -# -# async def test_multiple_pk(self): -# with self.assertRaisesRegex( -# ConfigurationError, -# "Can't create model Tournament with two primary keys, only single primary key is supported", -# ): -# await Tortoise.init( -# { -# "connections": { -# "default": { -# "engine": "tortoise.backends.sqlite", -# "credentials": {"file_path": ":memory:"}, -# } -# }, -# "apps": { -# "models": { -# "models": ["tests.model_setup.model_multiple_pk"], -# "default_connection": "default", -# } -# }, -# } -# ) -# -# async def test_nonpk_id(self): -# with self.assertRaisesRegex( -# ConfigurationError, -# "Can't create model Tournament without explicit primary key if" -# " field 'id' already present", -# ): -# await Tortoise.init( -# { -# "connections": { -# "default": { -# "engine": "tortoise.backends.sqlite", -# "credentials": {"file_path": ":memory:"}, -# } -# }, -# "apps": { -# "models": { -# "models": ["tests.model_setup.model_nonpk_id"], -# "default_connection": "default", -# } -# }, -# } -# ) -# -# async def test_unknown_connection(self): -# with self.assertRaisesRegex(ConfigurationError, 'Unknown connection "fioop"'): -# await Tortoise.init( -# { -# "connections": { -# "default": { -# "engine": "tortoise.backends.sqlite", -# "credentials": {"file_path": ":memory:"}, -# } -# }, -# "apps": { -# "models": {"models": ["tests.testmodels"], "default_connection": "fioop"} -# }, -# } -# ) -# -# async def test_url_without_modules(self): -# with self.assertRaisesRegex( -# ConfigurationError, 'You must specify "db_url" and "modules" together' -# ): -# await Tortoise.init(db_url=f"sqlite://{':memory:'}") -# -# async def test_default_connection_init(self): -# await Tortoise.init( -# { -# "connections": { -# "default": { -# "engine": "tortoise.backends.sqlite", -# "credentials": {"file_path": ":memory:"}, -# } -# }, -# "apps": {"models": {"models": ["tests.testmodels"]}}, -# } -# ) -# self.assertIn("models", Tortoise.apps) -# self.assertIsNotNone(Tortoise.get_connection("default")) -# -# async def test_db_url_init(self): -# await Tortoise.init( -# { -# "connections": {"default": f"sqlite://{':memory:'}"}, -# "apps": { -# "models": {"models": ["tests.testmodels"], "default_connection": "default"} -# }, -# } -# ) -# self.assertIn("models", Tortoise.apps) -# self.assertIsNotNone(Tortoise.get_connection("default")) -# -# async def test_shorthand_init(self): -# await Tortoise.init( -# db_url=f"sqlite://{':memory:'}", modules={"models": ["tests.testmodels"]} -# ) -# self.assertIn("models", Tortoise.apps) -# self.assertIsNotNone(Tortoise.get_connection("default")) -# -# async def test_init_wrong_connection_engine(self): -# with self.assertRaisesRegex(ImportError, "tortoise.backends.test"): -# await Tortoise.init( -# { -# "connections": { -# "default": { -# "engine": "tortoise.backends.test", -# "credentials": {"file_path": ":memory:"}, -# } -# }, -# "apps": { -# "models": {"models": ["tests.testmodels"], "default_connection": "default"} -# }, -# } -# ) -# -# async def test_init_wrong_connection_engine_2(self): -# with self.assertRaisesRegex( -# ConfigurationError, -# 'Backend for engine "tortoise.backends" does not implement db client', -# ): -# await Tortoise.init( -# { -# "connections": { -# "default": { -# "engine": "tortoise.backends", -# "credentials": {"file_path": ":memory:"}, -# } -# }, -# "apps": { -# "models": {"models": ["tests.testmodels"], "default_connection": "default"} -# }, -# } -# ) -# -# async def test_init_no_connections(self): -# with self.assertRaisesRegex(ConfigurationError, 'Config must define "connections" section'): -# await Tortoise.init( -# { -# "apps": { -# "models": {"models": ["tests.testmodels"], "default_connection": "default"} -# } -# } -# ) -# -# async def test_init_no_apps(self): -# with self.assertRaisesRegex(ConfigurationError, 'Config must define "apps" section'): -# await Tortoise.init( -# { -# "connections": { -# "default": { -# "engine": "tortoise.backends.sqlite", -# "credentials": {"file_path": ":memory:"}, -# } -# } -# } -# ) -# -# async def test_init_config_and_config_file(self): -# with self.assertRaisesRegex( -# ConfigurationError, 'You should init either from "config", "config_file" or "db_url"' -# ): -# await Tortoise.init( -# { -# "connections": { -# "default": { -# "engine": "tortoise.backends.sqlite", -# "credentials": {"file_path": ":memory:"}, -# } -# }, -# "apps": { -# "models": {"models": ["tests.testmodels"], "default_connection": "default"} -# }, -# }, -# config_file="file.json", -# ) -# -# async def test_init_config_file_wrong_extension(self): -# with self.assertRaisesRegex( -# ConfigurationError, "Unknown config extension .ini, only .yml and .json are supported" -# ): -# await Tortoise.init(config_file="config.ini") -# -# @test.skipIf(os.name == "nt", "path issue on Windows") -# async def test_init_json_file(self): -# await Tortoise.init(config_file=os.path.dirname(__file__) + "/init.json") -# self.assertIn("models", Tortoise.apps) -# self.assertIsNotNone(Tortoise.get_connection("default")) -# -# @test.skipIf(os.name == "nt", "path issue on Windows") -# async def test_init_yaml_file(self): -# await Tortoise.init(config_file=os.path.dirname(__file__) + "/init.yaml") -# self.assertIn("models", Tortoise.apps) -# self.assertIsNotNone(Tortoise.get_connection("default")) -# -# async def test_generate_schema_without_init(self): -# with self.assertRaisesRegex( -# ConfigurationError, r"You have to call \.init\(\) first before generating schemas" -# ): -# await Tortoise.generate_schemas() -# -# async def test_drop_databases_without_init(self): -# with self.assertRaisesRegex( -# ConfigurationError, r"You have to call \.init\(\) first before deleting schemas" -# ): -# await Tortoise._drop_databases() -# -# async def test_bad_models(self): -# with self.assertRaisesRegex(ConfigurationError, 'Module "tests.testmodels2" not found'): -# await Tortoise.init( -# { -# "connections": { -# "default": { -# "engine": "tortoise.backends.sqlite", -# "credentials": {"file_path": ":memory:"}, -# } -# }, -# "apps": { -# "models": {"models": ["tests.testmodels2"], "default_connection": "default"} -# }, -# } -# ) +import os + +from tortoise import Tortoise, connections +from tortoise.contrib import test +from tortoise.exceptions import ConfigurationError + + +class TestInitErrors(test.SimpleTestCase): + async def setUp(self): + try: + Tortoise.apps = {} + Tortoise._inited = False + except ConfigurationError: + pass + Tortoise._inited = False + + async def tearDown(self): + await Tortoise._reset_apps() + + async def test_basic_init(self): + await Tortoise.init( + { + "connections": { + "default": { + "engine": "tortoise.backends.sqlite", + "credentials": {"file_path": ":memory:"}, + } + }, + "apps": { + "models": {"models": ["tests.testmodels"], "default_connection": "default"} + }, + } + ) + self.assertIn("models", Tortoise.apps) + self.assertIsNotNone(connections.get("default")) + + async def test_empty_modules_init(self): + with self.assertWarnsRegex(RuntimeWarning, 'Module "tests.model_setup" has no models'): + await Tortoise.init( + { + "connections": { + "default": { + "engine": "tortoise.backends.sqlite", + "credentials": {"file_path": ":memory:"}, + } + }, + "apps": { + "models": {"models": ["tests.model_setup"], "default_connection": "default"} + }, + } + ) + + async def test_dup1_init(self): + with self.assertRaisesRegex( + ConfigurationError, 'backward relation "events" duplicates in model Tournament' + ): + await Tortoise.init( + { + "connections": { + "default": { + "engine": "tortoise.backends.sqlite", + "credentials": {"file_path": ":memory:"}, + } + }, + "apps": { + "models": { + "models": ["tests.model_setup.models_dup1"], + "default_connection": "default", + } + }, + } + ) + + async def test_dup2_init(self): + with self.assertRaisesRegex( + ConfigurationError, 'backward relation "events" duplicates in model Team' + ): + await Tortoise.init( + { + "connections": { + "default": { + "engine": "tortoise.backends.sqlite", + "credentials": {"file_path": ":memory:"}, + } + }, + "apps": { + "models": { + "models": ["tests.model_setup.models_dup2"], + "default_connection": "default", + } + }, + } + ) + + async def test_dup3_init(self): + with self.assertRaisesRegex( + ConfigurationError, 'backward relation "event" duplicates in model Tournament' + ): + await Tortoise.init( + { + "connections": { + "default": { + "engine": "tortoise.backends.sqlite", + "credentials": {"file_path": ":memory:"}, + } + }, + "apps": { + "models": { + "models": ["tests.model_setup.models_dup3"], + "default_connection": "default", + } + }, + } + ) + + async def test_generated_nonint(self): + with self.assertRaisesRegex( + ConfigurationError, "Field 'val' \\(CharField\\) can't be DB-generated" + ): + await Tortoise.init( + { + "connections": { + "default": { + "engine": "tortoise.backends.sqlite", + "credentials": {"file_path": ":memory:"}, + } + }, + "apps": { + "models": { + "models": ["tests.model_setup.model_generated_nonint"], + "default_connection": "default", + } + }, + } + ) + + async def test_multiple_pk(self): + with self.assertRaisesRegex( + ConfigurationError, + "Can't create model Tournament with two primary keys, only single primary key is supported", + ): + await Tortoise.init( + { + "connections": { + "default": { + "engine": "tortoise.backends.sqlite", + "credentials": {"file_path": ":memory:"}, + } + }, + "apps": { + "models": { + "models": ["tests.model_setup.model_multiple_pk"], + "default_connection": "default", + } + }, + } + ) + + async def test_nonpk_id(self): + with self.assertRaisesRegex( + ConfigurationError, + "Can't create model Tournament without explicit primary key if" + " field 'id' already present", + ): + await Tortoise.init( + { + "connections": { + "default": { + "engine": "tortoise.backends.sqlite", + "credentials": {"file_path": ":memory:"}, + } + }, + "apps": { + "models": { + "models": ["tests.model_setup.model_nonpk_id"], + "default_connection": "default", + } + }, + } + ) + + async def test_unknown_connection(self): + with self.assertRaisesRegex(ConfigurationError, 'Unknown connection "fioop"'): + await Tortoise.init( + { + "connections": { + "default": { + "engine": "tortoise.backends.sqlite", + "credentials": {"file_path": ":memory:"}, + } + }, + "apps": { + "models": {"models": ["tests.testmodels"], "default_connection": "fioop"} + }, + } + ) + + async def test_url_without_modules(self): + with self.assertRaisesRegex( + ConfigurationError, 'You must specify "db_url" and "modules" together' + ): + await Tortoise.init(db_url=f"sqlite://{':memory:'}") + + async def test_default_connection_init(self): + await Tortoise.init( + { + "connections": { + "default": { + "engine": "tortoise.backends.sqlite", + "credentials": {"file_path": ":memory:"}, + } + }, + "apps": {"models": {"models": ["tests.testmodels"]}}, + } + ) + self.assertIn("models", Tortoise.apps) + self.assertIsNotNone(connections.get("default")) + + async def test_db_url_init(self): + await Tortoise.init( + { + "connections": {"default": f"sqlite://{':memory:'}"}, + "apps": { + "models": {"models": ["tests.testmodels"], "default_connection": "default"} + }, + } + ) + self.assertIn("models", Tortoise.apps) + self.assertIsNotNone(connections.get("default")) + + async def test_shorthand_init(self): + await Tortoise.init( + db_url=f"sqlite://{':memory:'}", modules={"models": ["tests.testmodels"]} + ) + self.assertIn("models", Tortoise.apps) + self.assertIsNotNone(connections.get("default")) + + async def test_init_wrong_connection_engine(self): + with self.assertRaisesRegex(ImportError, "tortoise.backends.test"): + await Tortoise.init( + { + "connections": { + "default": { + "engine": "tortoise.backends.test", + "credentials": {"file_path": ":memory:"}, + } + }, + "apps": { + "models": {"models": ["tests.testmodels"], "default_connection": "default"} + }, + } + ) + + async def test_init_wrong_connection_engine_2(self): + with self.assertRaisesRegex( + ConfigurationError, + 'Backend for engine "tortoise.backends" does not implement db client', + ): + await Tortoise.init( + { + "connections": { + "default": { + "engine": "tortoise.backends", + "credentials": {"file_path": ":memory:"}, + } + }, + "apps": { + "models": {"models": ["tests.testmodels"], "default_connection": "default"} + }, + } + ) + + async def test_init_no_connections(self): + with self.assertRaisesRegex(ConfigurationError, 'Config must define "connections" section'): + await Tortoise.init( + { + "apps": { + "models": {"models": ["tests.testmodels"], "default_connection": "default"} + } + } + ) + + async def test_init_no_apps(self): + with self.assertRaisesRegex(ConfigurationError, 'Config must define "apps" section'): + await Tortoise.init( + { + "connections": { + "default": { + "engine": "tortoise.backends.sqlite", + "credentials": {"file_path": ":memory:"}, + } + } + } + ) + + async def test_init_config_and_config_file(self): + with self.assertRaisesRegex( + ConfigurationError, 'You should init either from "config", "config_file" or "db_url"' + ): + await Tortoise.init( + { + "connections": { + "default": { + "engine": "tortoise.backends.sqlite", + "credentials": {"file_path": ":memory:"}, + } + }, + "apps": { + "models": {"models": ["tests.testmodels"], "default_connection": "default"} + }, + }, + config_file="file.json", + ) + + async def test_init_config_file_wrong_extension(self): + with self.assertRaisesRegex( + ConfigurationError, "Unknown config extension .ini, only .yml and .json are supported" + ): + await Tortoise.init(config_file="config.ini") + + @test.skipIf(os.name == "nt", "path issue on Windows") + async def test_init_json_file(self): + await Tortoise.init(config_file=os.path.dirname(__file__) + "/init.json") + self.assertIn("models", Tortoise.apps) + self.assertIsNotNone(connections.get("default")) + + @test.skipIf(os.name == "nt", "path issue on Windows") + async def test_init_yaml_file(self): + await Tortoise.init(config_file=os.path.dirname(__file__) + "/init.yaml") + self.assertIn("models", Tortoise.apps) + self.assertIsNotNone(connections.get("default")) + + async def test_generate_schema_without_init(self): + with self.assertRaisesRegex( + ConfigurationError, r"You have to call \.init\(\) first before generating schemas" + ): + await Tortoise.generate_schemas() + + async def test_drop_databases_without_init(self): + with self.assertRaisesRegex( + ConfigurationError, r"You have to call \.init\(\) first before deleting schemas" + ): + await Tortoise._drop_databases() + + async def test_bad_models(self): + with self.assertRaisesRegex(ConfigurationError, 'Module "tests.testmodels2" not found'): + await Tortoise.init( + { + "connections": { + "default": { + "engine": "tortoise.backends.sqlite", + "credentials": {"file_path": ":memory:"}, + } + }, + "apps": { + "models": {"models": ["tests.testmodels2"], "default_connection": "default"} + }, + } + ) From 1fd9daca9f73e5484e6ecd2d30077780325e0baa Mon Sep 17 00:00:00 2001 From: Aditya N Date: Sun, 5 Dec 2021 18:32:52 +0530 Subject: [PATCH 03/25] - Fixed linting errors (flake8 and mypy) --- tests/utils/test_run_async.py | 2 +- tortoise/__init__.py | 4 +-- tortoise/backends/base/client.py | 10 +++--- tortoise/connection.py | 55 ++++++++++++++++--------------- tortoise/contrib/test/__init__.py | 31 +++++++++-------- tortoise/models.py | 6 +++- tortoise/router.py | 2 +- 7 files changed, 57 insertions(+), 53 deletions(-) diff --git a/tests/utils/test_run_async.py b/tests/utils/test_run_async.py index 58a0fa213..337c1c39c 100644 --- a/tests/utils/test_run_async.py +++ b/tests/utils/test_run_async.py @@ -1,7 +1,7 @@ import os from unittest import TestCase, skipIf -from tortoise import Tortoise, run_async, connections +from tortoise import Tortoise, connections, run_async @skipIf(os.name == "nt", "stuck with Windows") diff --git a/tortoise/__init__.py b/tortoise/__init__.py index e32d82660..ff98fe6ae 100644 --- a/tortoise/__init__.py +++ b/tortoise/__init__.py @@ -7,9 +7,7 @@ from copy import deepcopy from inspect import isclass from types import ModuleType -from typing import ( - Coroutine, Dict, Iterable, List, Optional, Tuple, Type, Union, cast -) +from typing import Coroutine, Dict, Iterable, List, Optional, Tuple, Type, Union, cast from pypika import Table diff --git a/tortoise/backends/base/client.py b/tortoise/backends/base/client.py index d35821ba6..9c136a8bc 100644 --- a/tortoise/backends/base/client.py +++ b/tortoise/backends/base/client.py @@ -3,9 +3,9 @@ from pypika import Query -from tortoise.connection import connections from tortoise.backends.base.executor import BaseExecutor from tortoise.backends.base.schema_generator import BaseSchemaGenerator +from tortoise.connection import connections from tortoise.exceptions import TransactionManagementError from tortoise.log import db_client_logger @@ -210,7 +210,7 @@ def __init__(self, lock: asyncio.Lock, client: Any) -> None: self.client = client self.connection: Any = client._connection - async def ensure_connection(self): + async def ensure_connection(self) -> None: if not self.connection: await self.client.create_connection(with_db=True) self.connection = self.client._connection @@ -232,7 +232,7 @@ def __init__(self, connection: Any) -> None: self.connection_name = connection.connection_name self.lock = getattr(connection, "_trxlock", None) - async def ensure_connection(self): + async def ensure_connection(self) -> None: if not self.connection._connection: await self.connection._parent.create_connection(with_db=True) self.connection._connection = self.connection._parent._connection @@ -259,7 +259,7 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: class TransactionContextPooled(TransactionContext): __slots__ = ("conn_wrapper", "connection", "connection_name", "token") - async def ensure_connection(self): + async def ensure_connection(self) -> None: if not self.connection._parent._pool: await self.connection._parent.create_connection(with_db=True) @@ -314,7 +314,7 @@ def __init__(self, client: Any) -> None: self.client = client self.connection = None - async def ensure_connection(self): + async def ensure_connection(self) -> None: if not self.pool: await self.client.create_connection(with_db=True) self.pool = self.client._pool diff --git a/tortoise/connection.py b/tortoise/connection.py index 173ae0a4d..979bcb697 100644 --- a/tortoise/connection.py +++ b/tortoise/connection.py @@ -3,23 +3,23 @@ import copy import importlib from contextvars import ContextVar -from typing import Dict, Optional, Any, Type, Union, List, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union from tortoise.backends.base.config_generator import expand_db_url from tortoise.exceptions import ConfigurationError - if TYPE_CHECKING: from tortoise.backends.base.client import BaseDBAsyncClient + DBConfigType = Dict[str, Any] class ConnectionHandler: - _conn_storage: ContextVar[Dict[str, 'BaseDBAsyncClient']] = contextvars.ContextVar( + _conn_storage: ContextVar[Dict[str, "BaseDBAsyncClient"]] = contextvars.ContextVar( "_conn_storage", default={} ) - def __init__(self): - self._db_config: Optional[Dict[str, Any]] = None + def __init__(self) -> None: + self._db_config: Optional["DBConfigType"] = None self._create_db: bool = False async def _init(self, db_config: Dict[str, Any], create_db: bool): @@ -28,19 +28,25 @@ async def _init(self, db_config: Dict[str, Any], create_db: bool): await self._init_connections() @property - def db_config(self) -> Dict[str, Any]: + def db_config(self) -> "DBConfigType": + if self._db_config is None: + raise ConfigurationError( + "DB configuration not initialised. Make sure to call " + "Tortoise.init with a valid configuration before attempting " + "to create connections." + ) return self._db_config - def _get_storage(self) -> Dict[str, 'BaseDBAsyncClient']: + def _get_storage(self) -> Dict[str, "BaseDBAsyncClient"]: return self._conn_storage.get() - def _copy_storage(self): + def _copy_storage(self) -> Dict[str, "BaseDBAsyncClient"]: return copy.copy(self._get_storage()) def _clear_storage(self) -> None: self._get_storage().clear() - def _discover_client_class(self, engine: str) -> Type['BaseDBAsyncClient']: + def _discover_client_class(self, engine: str) -> Type["BaseDBAsyncClient"]: # Let exception bubble up for transparency engine_module = importlib.import_module(engine) @@ -52,34 +58,31 @@ def _discover_client_class(self, engine: str) -> Type['BaseDBAsyncClient']: def _get_db_info(self, conn_alias: str) -> Union[str, Dict]: try: - return self._db_config[conn_alias] + return self.db_config[conn_alias] except KeyError: raise ConfigurationError( - f'Unable to get db settings for alias {conn_alias}. Please ' - f'check if the config dict contains this alias and try again' + f"Unable to get db settings for alias {conn_alias}. Please " + f"check if the config dict contains this alias and try again" ) async def _init_connections(self) -> None: - # print("storage within init is ", self._get_storage()) for alias in self.db_config: - connection: 'BaseDBAsyncClient' = self.get(alias) + connection: "BaseDBAsyncClient" = self.get(alias) if self._create_db: await connection.db_create() - def _create_connection(self, conn_alias: str) -> 'BaseDBAsyncClient': - db_info: Union[str, Dict] = self._get_db_info(conn_alias) + def _create_connection(self, conn_alias: str) -> "BaseDBAsyncClient": + db_info = self._get_db_info(conn_alias) if isinstance(db_info, str): db_info = expand_db_url(db_info) - client_class: Type['BaseDBAsyncClient'] = self._discover_client_class( - db_info.get("engine") - ) + client_class = self._discover_client_class(db_info.get("engine", "")) db_params = db_info["credentials"].copy() db_params.update({"connection_name": conn_alias}) - connection: 'BaseDBAsyncClient' = client_class(**db_params) + connection: "BaseDBAsyncClient" = client_class(**db_params) return connection - def get(self, conn_alias: str) -> 'BaseDBAsyncClient': - storage: Dict[str, 'BaseDBAsyncClient'] = self._get_storage() + def get(self, conn_alias: str) -> "BaseDBAsyncClient": + storage: Dict[str, "BaseDBAsyncClient"] = self._get_storage() try: return storage[conn_alias] except KeyError: @@ -87,7 +90,7 @@ def get(self, conn_alias: str) -> 'BaseDBAsyncClient': storage[conn_alias] = connection return connection - def _set_storage(self, new_storage: Dict[str, 'BaseDBAsyncClient']) -> contextvars.Token: + def _set_storage(self, new_storage: Dict[str, "BaseDBAsyncClient"]) -> contextvars.Token: # Should be used only for testing purposes. return self._conn_storage.set(new_storage) @@ -96,7 +99,7 @@ def set(self, conn_alias: str, conn) -> contextvars.Token: storage_copy[conn_alias] = conn return self._conn_storage.set(storage_copy) - def discard(self, conn_alias: str) -> Optional['BaseDBAsyncClient']: + def discard(self, conn_alias: str) -> Optional["BaseDBAsyncClient"]: return self._get_storage().pop(conn_alias, None) def reset(self, token: contextvars.Token): @@ -107,7 +110,7 @@ def reset(self, token: contextvars.Token): if alias not in prev_storage: prev_storage[alias] = conn - def all(self) -> List['BaseDBAsyncClient']: + def all(self) -> List["BaseDBAsyncClient"]: # Returning a list here so as to avoid accidental # mutation of the underlying storage dict return list(self._get_storage().values()) @@ -120,4 +123,4 @@ async def close_all(self, discard: bool = False) -> None: self.discard(alias) -connections = ConnectionHandler() \ No newline at end of file +connections = ConnectionHandler() diff --git a/tortoise/contrib/test/__init__.py b/tortoise/contrib/test/__init__.py index 38759491f..15599f211 100644 --- a/tortoise/contrib/test/__init__.py +++ b/tortoise/contrib/test/__init__.py @@ -174,19 +174,8 @@ def _init_loop(self) -> None: else: # pragma: nocoverage self.loop = asyncio.new_event_loop() - async def _setUpDB(self) -> None: - # setting storage to an empty dict explicitly to create a - # ContextVar specific scope for each test case. The storage will - # either be restored from _restore_default or re-initialised depending on - # how the test case should behave - self.token = connections._set_storage({}) - - async def _tearDownDB(self) -> None: - pass - async def _setUp(self) -> None: self.token: Optional[contextvars.Token] = None - await self._setUpDB() if asyncio.iscoroutinefunction(self.setUp): await self.asyncSetUp() @@ -196,11 +185,21 @@ async def _setUp(self) -> None: # don't take into account if the loop ran during setUp self.loop._asynctest_ran = False # type: ignore - def _reset_conn_state(self): + async def _setUpDB(self) -> None: + # setting storage to an empty dict explicitly to create a + # ContextVar specific scope for each test case. The storage will + # either be restored from _restore_default or re-initialised depending on + # how the test case should behave + self.token = connections._set_storage({}) + + async def _tearDownDB(self) -> None: + pass + + def _reset_conn_state(self) -> None: # clearing the storage and restoring to previous storage state # of the contextvar connections._clear_storage() - connections._db_config.clear() + connections.db_config.clear() if self.token: connections.reset(self.token) @@ -357,10 +356,10 @@ def __init__(self, connection) -> None: self.connection_name = connection.connection_name self.uses_pool = hasattr(self.connection._parent, "_pool") - async def ensure_connection(self): - is_conn_established = self.connection._connection != None + async def ensure_connection(self) -> None: + is_conn_established = self.connection._connection is not None if self.uses_pool: - is_conn_established = self.connection._parent._pool != None + is_conn_established = self.connection._parent._pool is not None # If the underlying pool/connection hasn't been established then # first create the pool/connection diff --git a/tortoise/models.py b/tortoise/models.py index 9981d8146..bb16c70ee 100644 --- a/tortoise/models.py +++ b/tortoise/models.py @@ -52,7 +52,7 @@ from tortoise.queryset import BulkUpdateQuery, ExistsQuery, Q, QuerySet, QuerySetSingle, RawSQLQuery from tortoise.router import router from tortoise.signals import Signals -from tortoise.transactions import current_transaction_map, in_transaction +from tortoise.transactions import in_transaction MODEL = TypeVar("MODEL", bound="Model") EMPTY = object() @@ -271,6 +271,10 @@ def add_field(self, name: str, value: Field) -> None: @property def db(self) -> BaseDBAsyncClient: + if self.default_connection is None: + raise ConfigurationError( + f"default_connection for the model {self._model} cannot be None" + ) return connections.get(self.default_connection) @property diff --git a/tortoise/router.py b/tortoise/router.py index f2e14d3a9..7ce5f8503 100644 --- a/tortoise/router.py +++ b/tortoise/router.py @@ -1,7 +1,7 @@ from typing import TYPE_CHECKING, List, Optional, Type -from tortoise.exceptions import ConfigurationError from tortoise.connection import connections +from tortoise.exceptions import ConfigurationError if TYPE_CHECKING: from tortoise import BaseDBAsyncClient, Model From faa4d87e10180367fb69e0049375d6fae672fcff Mon Sep 17 00:00:00 2001 From: Aditya N Date: Mon, 6 Dec 2021 01:00:14 +0530 Subject: [PATCH 04/25] - Modified comment --- tortoise/contrib/test/__init__.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tortoise/contrib/test/__init__.py b/tortoise/contrib/test/__init__.py index 15599f211..ff7daec04 100644 --- a/tortoise/contrib/test/__init__.py +++ b/tortoise/contrib/test/__init__.py @@ -403,10 +403,10 @@ async def _run_outcome(self, outcome, expecting_failure, testMethod) -> None: await super()._run_outcome(outcome, expecting_failure, testMethod) # Clearing out the storage so as to provide a clean storage state - # for all tests. Have to explicitly clear it here since teardown - # gets called as part of _run_outcome which would only clear the nested - # storage in the contextvar. This is not needed for other testcase classes - # as they don't run inside a transaction + # for all tests. Have to explicitly clear it here since the teardown which + # gets called in _run_outcome would only clear the nested + # storage created in the contextvar due to TransactionTestContext. + # This is not needed for other testcase classes as they don't run inside a transaction self._reset_conn_state() async def _setUpDB(self) -> None: From 79b187a7e5a82ddd839b9e896fd492df00a59d67 Mon Sep 17 00:00:00 2001 From: Aditya N Date: Tue, 7 Dec 2021 00:44:34 +0530 Subject: [PATCH 05/25] - Code cleanup --- tortoise/contrib/test/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tortoise/contrib/test/__init__.py b/tortoise/contrib/test/__init__.py index ff7daec04..71619a20e 100644 --- a/tortoise/contrib/test/__init__.py +++ b/tortoise/contrib/test/__init__.py @@ -394,7 +394,7 @@ class TestCase(TruncationTestCase): async def _run_outcome(self, outcome, expecting_failure, testMethod) -> None: _restore_default() - self.__db__ = Tortoise.get_connection("models") + self.__db__ = connections.get("models") if self.__db__.capabilities.supports_transactions: connection = self.__db__._in_transaction().connection async with TransactionTestContext(connection): From 8184432078a33b501882c6ad664f16b58647b2e4 Mon Sep 17 00:00:00 2001 From: Aditya N Date: Tue, 7 Dec 2021 14:25:47 +0530 Subject: [PATCH 06/25] - Fixed style issues --- tortoise/connection.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tortoise/connection.py b/tortoise/connection.py index 979bcb697..a74488fbf 100644 --- a/tortoise/connection.py +++ b/tortoise/connection.py @@ -10,6 +10,7 @@ if TYPE_CHECKING: from tortoise.backends.base.client import BaseDBAsyncClient + DBConfigType = Dict[str, Any] @@ -22,7 +23,7 @@ def __init__(self) -> None: self._db_config: Optional["DBConfigType"] = None self._create_db: bool = False - async def _init(self, db_config: Dict[str, Any], create_db: bool): + async def _init(self, db_config: "DBConfigType", create_db: bool): self._db_config = db_config self._create_db = create_db await self._init_connections() From d3e7403c1e6896f4f85adc153c1b6e7321e0af52 Mon Sep 17 00:00:00 2001 From: Aditya N Date: Sun, 19 Dec 2021 19:32:56 +0530 Subject: [PATCH 07/25] - Fixed mysql breaking tests --- tortoise/contrib/test/__init__.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tortoise/contrib/test/__init__.py b/tortoise/contrib/test/__init__.py index 71619a20e..e559449a5 100644 --- a/tortoise/contrib/test/__init__.py +++ b/tortoise/contrib/test/__init__.py @@ -11,7 +11,7 @@ from tortoise import Model, Tortoise, connections from tortoise.backends.base.config_generator import generate_config as _generate_config -from tortoise.exceptions import DBConnectionError +from tortoise.exceptions import DBConnectionError, OperationalError __all__ = ( "SimpleTestCase", @@ -62,10 +62,12 @@ def getDBConfig(app_label: str, modules: Iterable[Union[str, ModuleType]]) -> di async def _init_db(config: dict) -> None: + # Placing init outside the try block since it doesn't + # establish connections to the DB eagerly. + await Tortoise.init(config) try: - await Tortoise.init(config) await Tortoise._drop_databases() - except DBConnectionError: # pragma: nocoverage + except (DBConnectionError, OperationalError): # pragma: nocoverage pass await Tortoise.init(config, _create_db=True) From 18912efd3bab3d3e6cc52db58f36147327e1ff80 Mon Sep 17 00:00:00 2001 From: Aditya N Date: Thu, 13 Jan 2022 16:37:59 +0530 Subject: [PATCH 08/25] - Merged latest head from develop and fixed breaking tests --- tests/backends/test_mysql.py | 5 +++-- tests/backends/test_postgres.py | 3 ++- .../test_bad_relation_reference.py | 3 +-- tests/model_setup/test_init.py | 6 +++++- tests/schema/test_generate_schema.py | 1 + tests/test_concurrency.py | 1 + tortoise/__init__.py | 6 +----- tortoise/backends/base/client.py | 1 + tortoise/connection.py | 2 +- tortoise/contrib/test/__init__.py | 21 +++---------------- 10 files changed, 19 insertions(+), 30 deletions(-) diff --git a/tests/backends/test_mysql.py b/tests/backends/test_mysql.py index cb5babc02..4f5b427df 100644 --- a/tests/backends/test_mysql.py +++ b/tests/backends/test_mysql.py @@ -18,16 +18,17 @@ async def asyncSetUp(self): async def asyncTearDown(self) -> None: if Tortoise._inited: await Tortoise._drop_databases() + await super().asyncTearDown() async def test_bad_charset(self): self.db_config["connections"]["models"]["credentials"]["charset"] = "terrible" with self.assertRaisesRegex(ConnectionError, "Unknown charset"): - await Tortoise.init(self.db_config) + await Tortoise.init(self.db_config, _create_db=True) async def test_ssl_true(self): self.db_config["connections"]["models"]["credentials"]["ssl"] = True try: - await Tortoise.init(self.db_config) + await Tortoise.init(self.db_config, _create_db=True) except (ConnectionError, ssl.SSLError): pass else: diff --git a/tests/backends/test_postgres.py b/tests/backends/test_postgres.py index 0551b3141..055e2cefb 100644 --- a/tests/backends/test_postgres.py +++ b/tests/backends/test_postgres.py @@ -21,6 +21,7 @@ async def asyncSetUp(self): async def asyncTearDown(self) -> None: if Tortoise._inited: await Tortoise._drop_databases() + await super().asyncTearDown() async def test_schema(self): from asyncpg.exceptions import InvalidSchemaNameError @@ -56,7 +57,7 @@ async def test_schema(self): async def test_ssl_true(self): self.db_config["connections"]["models"]["credentials"]["ssl"] = True try: - await Tortoise.init(self.db_config) + await Tortoise.init(self.db_config, _create_db=True) except (ConnectionError, ssl.SSLError): pass else: diff --git a/tests/model_setup/test_bad_relation_reference.py b/tests/model_setup/test_bad_relation_reference.py index 5d7c25d9a..7fcf721e3 100644 --- a/tests/model_setup/test_bad_relation_reference.py +++ b/tests/model_setup/test_bad_relation_reference.py @@ -5,16 +5,15 @@ class TestBadReleationReferenceErrors(test.SimpleTestCase): async def asyncSetUp(self): + await super().asyncSetUp() try: Tortoise.apps = {} - Tortoise._connections = {} Tortoise._inited = False except ConfigurationError: pass Tortoise._inited = False async def asyncTearDown(self) -> None: - await Tortoise.close_connections() await Tortoise._reset_apps() await super(TestBadReleationReferenceErrors, self).asyncTearDown() diff --git a/tests/model_setup/test_init.py b/tests/model_setup/test_init.py index 3b9f74ddd..bd8a32e36 100644 --- a/tests/model_setup/test_init.py +++ b/tests/model_setup/test_init.py @@ -182,7 +182,11 @@ async def test_nonpk_id(self): ) async def test_unknown_connection(self): - with self.assertRaisesRegex(ConfigurationError, 'Unknown connection "fioop"'): + with self.assertRaisesRegex( + ConfigurationError, + "Unable to get db settings for alias 'fioop'. Please " + "check if the config dict contains this alias and try again" + ): await Tortoise.init( { "connections": { diff --git a/tests/schema/test_generate_schema.py b/tests/schema/test_generate_schema.py index 10dc68fa7..a96293c58 100644 --- a/tests/schema/test_generate_schema.py +++ b/tests/schema/test_generate_schema.py @@ -25,6 +25,7 @@ async def asyncSetUp(self): async def asyncTearDown(self) -> None: await Tortoise._reset_apps() + await super().asyncTearDown() async def init_for(self, module: str, safe=False) -> None: with patch( diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py index 405491ee4..590463679 100644 --- a/tests/test_concurrency.py +++ b/tests/test_concurrency.py @@ -2,6 +2,7 @@ import sys from tests.testmodels import Tournament, UniqueName +from tortoise import connections from tortoise.contrib import test from tortoise.transactions import in_transaction diff --git a/tortoise/__init__.py b/tortoise/__init__.py index 5e3ad3fe1..9568c7cc8 100644 --- a/tortoise/__init__.py +++ b/tortoise/__init__.py @@ -630,11 +630,7 @@ async def close_connections(cls) -> None: else your event loop may never complete as it is waiting for the connections to die. """ - tasks = [] - for connection in cls._connections.values(): - tasks.append(connection.close()) - await asyncio.gather(*tasks) - cls._connections = {} + await connections.close_all(discard=True) logger.info("Tortoise-ORM shutdown") @classmethod diff --git a/tortoise/backends/base/client.py b/tortoise/backends/base/client.py index 213598d18..f69bb1def 100644 --- a/tortoise/backends/base/client.py +++ b/tortoise/backends/base/client.py @@ -280,6 +280,7 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: await self.connection.commit() if self.connection._parent._pool: await self.connection._parent._pool.release(self.connection._connection) + connections.reset(self.token) class NestedTransactionContext(TransactionContext): diff --git a/tortoise/connection.py b/tortoise/connection.py index a74488fbf..b519779da 100644 --- a/tortoise/connection.py +++ b/tortoise/connection.py @@ -62,7 +62,7 @@ def _get_db_info(self, conn_alias: str) -> Union[str, Dict]: return self.db_config[conn_alias] except KeyError: raise ConfigurationError( - f"Unable to get db settings for alias {conn_alias}. Please " + f"Unable to get db settings for alias '{conn_alias}'. Please " f"check if the config dict contains this alias and try again" ) diff --git a/tortoise/contrib/test/__init__.py b/tortoise/contrib/test/__init__.py index f1baaf364..6c44528db 100644 --- a/tortoise/contrib/test/__init__.py +++ b/tortoise/contrib/test/__init__.py @@ -168,11 +168,7 @@ class SimpleTestCase(unittest.IsolatedAsyncioTestCase): """ async def _setUpDB(self) -> None: - # setting storage to an empty dict explicitly to create a - # ContextVar specific scope for each test case. The storage will - # either be restored from _restore_default or re-initialised depending on - # how the test case should behave - self.token = connections._set_storage({}) + pass async def _tearDownDB(self) -> None: pass @@ -195,12 +191,9 @@ async def asyncSetUp(self) -> None: await self._setUpDB() def _reset_conn_state(self) -> None: - # clearing the storage and restoring to previous storage state - # of the contextvar + # clearing the storage and db config connections._clear_storage() connections.db_config.clear() - if self.token: - connections.reset(self.token) async def asyncTearDown(self) -> None: await self._tearDownDB() @@ -317,7 +310,6 @@ class TestCase(TruncationTestCase): async def asyncSetUp(self) -> None: await super(TestCase, self).asyncSetUp() - _restore_default() self.__db__ = connections.get("models") self.__transaction__ = TransactionTestContext(self.__db__._in_transaction().connection) await self.__transaction__.__aenter__() # type: ignore @@ -326,13 +318,6 @@ async def asyncTearDown(self) -> None: await self.__transaction__.__aexit__(None, None, None) await super(TestCase, self).asyncTearDown() - # Clearing out the storage so as to provide a clean storage state - # for all tests. Have to explicitly clear it here since the teardown which - # gets called in _run_outcome would only clear the nested - # storage created in the contextvar due to TransactionTestContext. - # This is not needed for other testcase classes as they don't run inside a transaction - self._reset_conn_state() - async def _setUpDB(self) -> None: await super(TestCase, self)._setUpDB() @@ -375,7 +360,7 @@ def decorator(test_item): @wraps(test_item) def skip_wrapper(*args, **kwargs): - db = Tortoise.get_connection(connection_name) + db = connections.get(connection_name) for key, val in conditions.items(): if getattr(db.capabilities, key) != val: raise SkipTest(f"Capability {key} != {val}") From 9ed0e713dfcde0c91a139d96a190c8e23377146a Mon Sep 17 00:00:00 2001 From: Aditya N Date: Thu, 13 Jan 2022 16:51:03 +0530 Subject: [PATCH 09/25] - Fixed style and lint issues --- tests/model_setup/test_init.py | 6 +++--- tests/test_concurrency.py | 1 - tortoise/connection.py | 3 +-- tortoise/contrib/test/__init__.py | 1 - 4 files changed, 4 insertions(+), 7 deletions(-) diff --git a/tests/model_setup/test_init.py b/tests/model_setup/test_init.py index bd8a32e36..26d4a8682 100644 --- a/tests/model_setup/test_init.py +++ b/tests/model_setup/test_init.py @@ -183,9 +183,9 @@ async def test_nonpk_id(self): async def test_unknown_connection(self): with self.assertRaisesRegex( - ConfigurationError, - "Unable to get db settings for alias 'fioop'. Please " - "check if the config dict contains this alias and try again" + ConfigurationError, + "Unable to get db settings for alias 'fioop'. Please " + "check if the config dict contains this alias and try again", ): await Tortoise.init( { diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py index 590463679..405491ee4 100644 --- a/tests/test_concurrency.py +++ b/tests/test_concurrency.py @@ -2,7 +2,6 @@ import sys from tests.testmodels import Tournament, UniqueName -from tortoise import connections from tortoise.contrib import test from tortoise.transactions import in_transaction diff --git a/tortoise/connection.py b/tortoise/connection.py index b519779da..a19e26a72 100644 --- a/tortoise/connection.py +++ b/tortoise/connection.py @@ -50,9 +50,8 @@ def _clear_storage(self) -> None: def _discover_client_class(self, engine: str) -> Type["BaseDBAsyncClient"]: # Let exception bubble up for transparency engine_module = importlib.import_module(engine) - try: - client_class = engine_module.client_class # type: ignore + client_class = engine_module.client_class except AttributeError: raise ConfigurationError(f'Backend for engine "{engine}" does not implement db client') return client_class diff --git a/tortoise/contrib/test/__init__.py b/tortoise/contrib/test/__init__.py index 6c44528db..c7095f7c2 100644 --- a/tortoise/contrib/test/__init__.py +++ b/tortoise/contrib/test/__init__.py @@ -1,5 +1,4 @@ import asyncio -import contextvars import os as _os import unittest from asyncio.events import AbstractEventLoop From 69d947eb00c8879797ccebf52a61fdfb38902ed2 Mon Sep 17 00:00:00 2001 From: Aditya N Date: Thu, 13 Jan 2022 17:40:13 +0530 Subject: [PATCH 10/25] - Fixed code quality issues (codacy) --- tortoise/backends/base/client.py | 2 ++ tortoise/connection.py | 1 + 2 files changed, 3 insertions(+) diff --git a/tortoise/backends/base/client.py b/tortoise/backends/base/client.py index f69bb1def..5972272e6 100644 --- a/tortoise/backends/base/client.py +++ b/tortoise/backends/base/client.py @@ -206,6 +206,7 @@ class ConnectionWrapper: __slots__ = ("connection", "lock", "client") def __init__(self, lock: asyncio.Lock, client: Any) -> None: + """Wraps the connections with a lock to facilitate safe concurrent access.""" self.lock: asyncio.Lock = lock self.client = client self.connection: Any = client._connection @@ -311,6 +312,7 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: class PoolConnectionWrapper: def __init__(self, client: Any) -> None: + """Class to manage acquiring from and releasing connections to a pool.""" self.pool = client._pool self.client = client self.connection = None diff --git a/tortoise/connection.py b/tortoise/connection.py index a19e26a72..972f46391 100644 --- a/tortoise/connection.py +++ b/tortoise/connection.py @@ -20,6 +20,7 @@ class ConnectionHandler: ) def __init__(self) -> None: + """Unified connection management interface.""" self._db_config: Optional["DBConfigType"] = None self._create_db: bool = False From 3b3b2e58fc41550e061db50de2685dec4473cf45 Mon Sep 17 00:00:00 2001 From: Aditya N Date: Thu, 13 Jan 2022 22:35:57 +0530 Subject: [PATCH 11/25] - Refactored core and test files to use the new connection interface --- examples/manual_sql.py | 4 ++-- examples/schema_create.py | 8 +++---- examples/two_databases.py | 6 ++--- tests/backends/test_capabilities.py | 4 ++-- tests/backends/test_connection_params.py | 18 +++++++------- tests/backends/test_postgres.py | 8 +++---- tests/backends/test_reconnect.py | 10 ++++---- tests/contrib/test_functions.py | 4 ++-- tests/schema/test_generate_schema.py | 30 ++++++++++++------------ tests/test_case_when.py | 4 ++-- tests/test_manual_sql.py | 12 +++++----- tests/test_queryset.py | 4 ++-- tortoise/__init__.py | 23 ++++-------------- tortoise/contrib/aiohttp/__init__.py | 4 ++-- tortoise/contrib/blacksheep/__init__.py | 4 ++-- tortoise/contrib/fastapi/__init__.py | 4 ++-- tortoise/contrib/quart/__init__.py | 4 ++-- tortoise/contrib/sanic/__init__.py | 4 ++-- tortoise/contrib/starlette/__init__.py | 4 ++-- 19 files changed, 74 insertions(+), 85 deletions(-) diff --git a/examples/manual_sql.py b/examples/manual_sql.py index 0205eed03..eab58d134 100644 --- a/examples/manual_sql.py +++ b/examples/manual_sql.py @@ -1,7 +1,7 @@ """ This example demonstrates executing manual SQL queries """ -from tortoise import Tortoise, fields, run_async +from tortoise import Tortoise, fields, run_async, connections from tortoise.models import Model from tortoise.transactions import in_transaction @@ -17,7 +17,7 @@ async def run(): await Tortoise.generate_schemas() # Need to get a connection. Unless explicitly specified, the name should be 'default' - conn = Tortoise.get_connection("default") + conn = connections.get("default") # Now we can execute queries in the normal autocommit mode await conn.execute_query("INSERT INTO event (name) VALUES ('Foo')") diff --git a/examples/schema_create.py b/examples/schema_create.py index 96f7f91c7..86f112a86 100644 --- a/examples/schema_create.py +++ b/examples/schema_create.py @@ -1,7 +1,7 @@ """ This example demonstrates SQL Schema generation for each DB type supported. """ -from tortoise import Tortoise, fields, run_async +from tortoise import Tortoise, fields, run_async, connections from tortoise.models import Model from tortoise.utils import get_schema_sql @@ -49,19 +49,19 @@ class Meta: async def run(): print("SQLite:\n") await Tortoise.init(db_url="sqlite://:memory:", modules={"models": ["__main__"]}) - sql = get_schema_sql(Tortoise.get_connection("default"), safe=False) + sql = get_schema_sql(connections.get("default"), safe=False) print(sql) print("\n\nMySQL:\n") await Tortoise.init(db_url="mysql://root:@127.0.0.1:3306/", modules={"models": ["__main__"]}) - sql = get_schema_sql(Tortoise.get_connection("default"), safe=False) + sql = get_schema_sql(connections.get("default"), safe=False) print(sql) print("\n\nPostgreSQL:\n") await Tortoise.init( db_url="postgres://postgres:@127.0.0.1:5432/", modules={"models": ["__main__"]} ) - sql = get_schema_sql(Tortoise.get_connection("default"), safe=False) + sql = get_schema_sql(connections.get("default"), safe=False) print(sql) diff --git a/examples/two_databases.py b/examples/two_databases.py index 88b8ed8fe..633c3a961 100644 --- a/examples/two_databases.py +++ b/examples/two_databases.py @@ -8,7 +8,7 @@ Key notes of this example is using db_route for Tortoise init and explicitly declaring model apps in class Meta """ -from tortoise import Tortoise, fields, run_async +from tortoise import Tortoise, fields, run_async, connections from tortoise.exceptions import OperationalError from tortoise.models import Model @@ -73,8 +73,8 @@ async def run(): } ) await Tortoise.generate_schemas() - client = Tortoise.get_connection("first") - second_client = Tortoise.get_connection("second") + client = connections.get("first") + second_client = connections.get("second") tournament = await Tournament.create(name="Tournament") await Event(name="Event", tournament_id=tournament.id).save() diff --git a/tests/backends/test_capabilities.py b/tests/backends/test_capabilities.py index 5acb4a9c9..54beaaa7a 100644 --- a/tests/backends/test_capabilities.py +++ b/tests/backends/test_capabilities.py @@ -1,4 +1,4 @@ -from tortoise import Tortoise +from tortoise import Tortoise, connections from tortoise.contrib import test @@ -7,7 +7,7 @@ class TestCapabilities(test.TestCase): async def asyncSetUp(self) -> None: await super(TestCapabilities, self).asyncSetUp() - self.db = Tortoise.get_connection("models") + self.db = connections.get("models") self.caps = self.db.capabilities def test_str(self): diff --git a/tests/backends/test_connection_params.py b/tests/backends/test_connection_params.py index 1f8fe4386..f5375e9f3 100644 --- a/tests/backends/test_connection_params.py +++ b/tests/backends/test_connection_params.py @@ -2,20 +2,20 @@ import asyncpg -from tortoise import Tortoise +from tortoise import connections from tortoise.contrib import test -class TestConnectionParams(test.TestCase): +class TestConnectionParams(test.SimpleTestCase): async def asyncSetUp(self) -> None: - pass + await super().asyncSetUp() async def asyncTearDown(self) -> None: - pass + await super().asyncTearDown() async def test_mysql_connection_params(self): with patch("asyncmy.create_pool", new=AsyncMock()) as mysql_connect: - await Tortoise._init_connections( + await connections._init( { "models": { "engine": "tortoise.backends.mysql", @@ -30,8 +30,9 @@ async def test_mysql_connection_params(self): }, } }, - False, + False ) + await connections.get("models").create_connection(with_db=True) mysql_connect.assert_awaited_once_with( # nosec autocommit=True, @@ -50,7 +51,7 @@ async def test_mysql_connection_params(self): async def test_postres_connection_params(self): try: with patch("asyncpg.create_pool", new=AsyncMock()) as asyncpg_connect: - await Tortoise._init_connections( + await connections._init( { "models": { "engine": "tortoise.backends.asyncpg", @@ -65,8 +66,9 @@ async def test_postres_connection_params(self): }, } }, - False, + False ) + await connections.get("models").create_connection(with_db=True) asyncpg_connect.assert_awaited_once_with( # nosec None, diff --git a/tests/backends/test_postgres.py b/tests/backends/test_postgres.py index 055e2cefb..19fa2543b 100644 --- a/tests/backends/test_postgres.py +++ b/tests/backends/test_postgres.py @@ -4,7 +4,7 @@ import ssl from tests.testmodels import Tournament -from tortoise import Tortoise +from tortoise import Tortoise, connections from tortoise.contrib import test from tortoise.exceptions import OperationalError @@ -32,7 +32,7 @@ async def test_schema(self): with self.assertRaises(InvalidSchemaNameError): await Tortoise.generate_schemas() - conn = Tortoise.get_connection("models") + conn = connections.get("models") await conn.execute_script("CREATE SCHEMA mytestschema;") await Tortoise.generate_schemas() @@ -45,7 +45,7 @@ async def test_schema(self): with self.assertRaises(OperationalError): await Tournament.filter(name="Test").first() - conn = Tortoise.get_connection("models") + conn = connections.get("models") _, res = await conn.execute_query( "SELECT id, name FROM mytestschema.tournament WHERE name='Test' LIMIT 1" ) @@ -81,7 +81,7 @@ async def test_application_name(self): ] = "mytest_application" await Tortoise.init(self.db_config, _create_db=True) - conn = Tortoise.get_connection("models") + conn = connections.get("models") _, res = await conn.execute_query( "SELECT application_name FROM pg_stat_activity WHERE pid = pg_backend_pid()" ) diff --git a/tests/backends/test_reconnect.py b/tests/backends/test_reconnect.py index 75b2eb04a..93766ccb3 100644 --- a/tests/backends/test_reconnect.py +++ b/tests/backends/test_reconnect.py @@ -1,5 +1,5 @@ from tests.testmodels import Tournament -from tortoise import Tortoise +from tortoise import Tortoise, connections from tortoise.contrib import test from tortoise.transactions import in_transaction @@ -9,11 +9,11 @@ class TestReconnect(test.IsolatedTestCase): async def test_reconnect(self): await Tournament.create(name="1") - await Tortoise._connections["models"]._expire_connections() + await connections.get("models")._expire_connections() await Tournament.create(name="2") - await Tortoise._connections["models"]._expire_connections() + await connections.get("models")._expire_connections() await Tournament.create(name="3") @@ -26,12 +26,12 @@ async def test_reconnect_transaction_start(self): async with in_transaction(): await Tournament.create(name="1") - await Tortoise._connections["models"]._expire_connections() + await connections.get("models")._expire_connections() async with in_transaction(): await Tournament.create(name="2") - await Tortoise._connections["models"]._expire_connections() + await connections.get("models")._expire_connections() async with in_transaction(): self.assertEqual([f"{a.id}:{a.name}" for a in await Tournament.all()], ["1:1", "2:2"]) diff --git a/tests/contrib/test_functions.py b/tests/contrib/test_functions.py index 5abf398f1..c11fad9cb 100644 --- a/tests/contrib/test_functions.py +++ b/tests/contrib/test_functions.py @@ -1,5 +1,5 @@ from tests.testmodels import IntFields -from tortoise import Tortoise +from tortoise import Tortoise, connections from tortoise.contrib import test from tortoise.contrib.mysql.functions import Rand from tortoise.contrib.postgres.functions import Random as PostgresRandom @@ -10,7 +10,7 @@ class TestFunction(test.TestCase): async def asyncSetUp(self): await super().asyncSetUp() self.intfields = [await IntFields.create(intnum=val) for val in range(10)] - self.db = Tortoise.get_connection("models") + self.db = connections.get("models") @test.requireCapability(dialect="mysql") async def test_mysql_func_rand(self): diff --git a/tests/schema/test_generate_schema.py b/tests/schema/test_generate_schema.py index a96293c58..58e4764fc 100644 --- a/tests/schema/test_generate_schema.py +++ b/tests/schema/test_generate_schema.py @@ -137,7 +137,7 @@ async def test_table_and_row_comment_generation(self): async def test_schema_no_db_constraint(self): self.maxDiff = None await self.init_for("tests.schema.models_no_db_constraint") - sql = get_schema_sql(Tortoise.get_connection("default"), safe=False) + sql = get_schema_sql(connections.get("default"), safe=False) self.assertEqual( sql.strip(), r"""CREATE TABLE "team" ( @@ -177,7 +177,7 @@ async def test_schema_no_db_constraint(self): async def test_schema(self): self.maxDiff = None await self.init_for("tests.schema.models_schema_create") - sql = get_schema_sql(Tortoise.get_connection("default"), safe=False) + sql = get_schema_sql(connections.get("default"), safe=False) self.assertEqual( sql.strip(), """ @@ -264,7 +264,7 @@ async def test_schema(self): async def test_schema_safe(self): self.maxDiff = None await self.init_for("tests.schema.models_schema_create") - sql = get_schema_sql(Tortoise.get_connection("default"), safe=True) + sql = get_schema_sql(connections.get("default"), safe=True) self.assertEqual( sql.strip(), """ @@ -351,7 +351,7 @@ async def test_schema_safe(self): async def test_m2m_no_auto_create(self): self.maxDiff = None await self.init_for("tests.schema.models_no_auto_create_m2m") - sql = get_schema_sql(Tortoise.get_connection("default"), safe=False) + sql = get_schema_sql(connections.get("default"), safe=False) self.assertEqual( sql.strip(), r"""CREATE TABLE "team" ( @@ -458,7 +458,7 @@ async def test_table_and_row_comment_generation(self): async def test_schema_no_db_constraint(self): self.maxDiff = None await self.init_for("tests.schema.models_no_db_constraint") - sql = get_schema_sql(Tortoise.get_connection("default"), safe=False) + sql = get_schema_sql(connections.get("default"), safe=False) self.assertEqual( sql.strip(), r"""CREATE TABLE `team` ( @@ -597,7 +597,7 @@ async def test_schema(self): async def test_schema_safe(self): self.maxDiff = None await self.init_for("tests.schema.models_schema_create") - sql = get_schema_sql(Tortoise.get_connection("default"), safe=True) + sql = get_schema_sql(connections.get("default"), safe=True) self.assertEqual( sql.strip(), @@ -696,7 +696,7 @@ async def test_schema_safe(self): async def test_index_safe(self): await self.init_for("tests.schema.models_mysql_index") - sql = get_schema_sql(Tortoise.get_connection("default"), safe=True) + sql = get_schema_sql(connections.get("default"), safe=True) self.assertEqual( sql, """CREATE TABLE IF NOT EXISTS `index` ( @@ -710,7 +710,7 @@ async def test_index_safe(self): async def test_index_unsafe(self): await self.init_for("tests.schema.models_mysql_index") - sql = get_schema_sql(Tortoise.get_connection("default"), safe=False) + sql = get_schema_sql(connections.get("default"), safe=False) self.assertEqual( sql, """CREATE TABLE `index` ( @@ -725,7 +725,7 @@ async def test_index_unsafe(self): async def test_m2m_no_auto_create(self): self.maxDiff = None await self.init_for("tests.schema.models_no_auto_create_m2m") - sql = get_schema_sql(Tortoise.get_connection("default"), safe=False) + sql = get_schema_sql(connections.get("default"), safe=False) self.assertEqual( sql.strip(), r"""CREATE TABLE `team` ( @@ -819,7 +819,7 @@ async def test_table_and_row_comment_generation(self): async def test_schema_no_db_constraint(self): self.maxDiff = None await self.init_for("tests.schema.models_no_db_constraint") - sql = get_schema_sql(Tortoise.get_connection("default"), safe=False) + sql = get_schema_sql(connections.get("default"), safe=False) self.assertEqual( sql.strip(), r"""CREATE TABLE "team" ( @@ -869,7 +869,7 @@ async def test_schema_no_db_constraint(self): async def test_schema(self): self.maxDiff = None await self.init_for("tests.schema.models_schema_create") - sql = get_schema_sql(Tortoise.get_connection("default"), safe=False) + sql = get_schema_sql(connections.get("default"), safe=False) self.assertEqual( sql.strip(), """ @@ -971,7 +971,7 @@ async def test_schema(self): async def test_schema_safe(self): self.maxDiff = None await self.init_for("tests.schema.models_schema_create") - sql = get_schema_sql(Tortoise.get_connection("default"), safe=True) + sql = get_schema_sql(connections.get("default"), safe=True) self.assertEqual( sql.strip(), """ @@ -1072,7 +1072,7 @@ async def test_schema_safe(self): async def test_index_unsafe(self): await self.init_for("tests.schema.models_postgres_index") - sql = get_schema_sql(Tortoise.get_connection("default"), safe=False) + sql = get_schema_sql(connections.get("default"), safe=False) self.assertEqual( sql, """CREATE TABLE "index" ( @@ -1094,7 +1094,7 @@ async def test_index_unsafe(self): async def test_index_safe(self): await self.init_for("tests.schema.models_postgres_index") - sql = get_schema_sql(Tortoise.get_connection("default"), safe=True) + sql = get_schema_sql(connections.get("default"), safe=True) self.assertEqual( sql, """CREATE TABLE IF NOT EXISTS "index" ( @@ -1117,7 +1117,7 @@ async def test_index_safe(self): async def test_m2m_no_auto_create(self): self.maxDiff = None await self.init_for("tests.schema.models_no_auto_create_m2m") - sql = get_schema_sql(Tortoise.get_connection("default"), safe=False) + sql = get_schema_sql(connections.get("default"), safe=False) self.assertEqual( sql.strip(), r"""CREATE TABLE "team" ( diff --git a/tests/test_case_when.py b/tests/test_case_when.py index 2ec15b97b..59a9a93f3 100644 --- a/tests/test_case_when.py +++ b/tests/test_case_when.py @@ -1,5 +1,5 @@ from tests.testmodels import IntFields -from tortoise import Tortoise +from tortoise import Tortoise, connections from tortoise.contrib import test from tortoise.expressions import Case, F, Q, When from tortoise.functions import Coalesce @@ -9,7 +9,7 @@ class TestCaseWhen(test.TestCase): async def asyncSetUp(self): await super().asyncSetUp() self.intfields = [await IntFields.create(intnum=val) for val in range(10)] - self.db = Tortoise.get_connection("models") + self.db = connections.get("models") async def test_single_when(self): category = Case(When(intnum__gte=8, then="big"), default="default") diff --git a/tests/test_manual_sql.py b/tests/test_manual_sql.py index eeac16078..1b8f92a99 100644 --- a/tests/test_manual_sql.py +++ b/tests/test_manual_sql.py @@ -1,11 +1,11 @@ -from tortoise import Tortoise +from tortoise import Tortoise, connections from tortoise.contrib import test from tortoise.transactions import in_transaction class TestManualSQL(test.TruncationTestCase): async def test_simple_insert(self): - conn = Tortoise.get_connection("models") + conn = connections.get("models") await conn.execute_query("INSERT INTO author (name) VALUES ('Foo')") self.assertEqual( await conn.execute_query_dict("SELECT name FROM author"), [{"name": "Foo"}] @@ -15,7 +15,7 @@ async def test_in_transaction(self): async with in_transaction() as conn: await conn.execute_query("INSERT INTO author (name) VALUES ('Foo')") - conn = Tortoise.get_connection("models") + conn = connections.get("models") self.assertEqual( await conn.execute_query_dict("SELECT name FROM author"), [{"name": "Foo"}] ) @@ -29,7 +29,7 @@ async def test_in_transaction_exception(self): except ValueError: pass - conn = Tortoise.get_connection("models") + conn = connections.get("models") self.assertEqual(await conn.execute_query_dict("SELECT name FROM author"), []) @test.requireCapability(supports_transactions=True) @@ -38,7 +38,7 @@ async def test_in_transaction_rollback(self): await conn.execute_query("INSERT INTO author (name) VALUES ('Foo')") await conn.rollback() - conn = Tortoise.get_connection("models") + conn = connections.get("models") self.assertEqual(await conn.execute_query_dict("SELECT name FROM author"), []) async def test_in_transaction_commit(self): @@ -50,7 +50,7 @@ async def test_in_transaction_commit(self): except ValueError: pass - conn = Tortoise.get_connection("models") + conn = connections.get("models") self.assertEqual( await conn.execute_query_dict("SELECT name FROM author"), [{"name": "Foo"}] ) diff --git a/tests/test_queryset.py b/tests/test_queryset.py index 66a58d0ae..5d56e7584 100644 --- a/tests/test_queryset.py +++ b/tests/test_queryset.py @@ -1,5 +1,5 @@ from tests.testmodels import Event, IntFields, MinRelation, Node, Reporter, Tournament, Tree -from tortoise import Tortoise +from tortoise import connections from tortoise.contrib import test from tortoise.exceptions import ( DoesNotExist, @@ -19,7 +19,7 @@ async def asyncSetUp(self): await super().asyncSetUp() # Build large dataset self.intfields = [await IntFields.create(intnum=val) for val in range(10, 100, 3)] - self.db = Tortoise.get_connection("models") + self.db = connections.get("models") async def test_all_count(self): self.assertEqual(await IntFields.all().count(), 30) diff --git a/tortoise/__init__.py b/tortoise/__init__.py index 9568c7cc8..9fedec309 100644 --- a/tortoise/__init__.py +++ b/tortoise/__init__.py @@ -3,7 +3,6 @@ import json import os import warnings -from contextvars import ContextVar from copy import deepcopy from inspect import isclass from types import ModuleType @@ -31,7 +30,6 @@ class Tortoise: apps: Dict[str, Dict[str, Type["Model"]]] = {} - _connections: Dict[str, BaseDBAsyncClient] = {} _inited: bool = False @classmethod @@ -40,6 +38,10 @@ def get_connection(cls, connection_name: str) -> BaseDBAsyncClient: Returns the connection by name. :raises KeyError: If connection name does not exist. + + .. warning:: + This is deprecated and will be removed in a future release. Please use + :meth:`tortoise.connection.connections.get` instead. """ return connections.get(connection_name) @@ -377,21 +379,6 @@ def _discover_models( warnings.warn(f'Module "{models_path}" has no models', RuntimeWarning, stacklevel=4) return discovered_models - @classmethod - async def _init_connections(cls, connections_config: dict, create_db: bool) -> None: - for name, info in connections_config.items(): - if isinstance(info, str): - info = expand_db_url(info) - client_class = cls._discover_client_class(info.get("engine")) - db_params = info["credentials"].copy() - db_params.update({"connection_name": name}) - connection = client_class(**db_params) - if create_db: - await connection.db_create() - await connection.create_connection(with_db=True) - cls._connections[name] = connection - current_transaction_map[name] = ContextVar(name, default=connection) - @classmethod def init_models( cls, @@ -424,7 +411,7 @@ def init_models( def _init_apps(cls, apps_config: dict) -> None: for name, info in apps_config.items(): try: - cls.get_connection(info.get("default_connection", "default")) + connections.get(info.get("default_connection", "default")) except KeyError: raise ConfigurationError( 'Unknown connection "{}" for app "{}"'.format( diff --git a/tortoise/contrib/aiohttp/__init__.py b/tortoise/contrib/aiohttp/__init__.py index 085bcb966..4781ec91b 100644 --- a/tortoise/contrib/aiohttp/__init__.py +++ b/tortoise/contrib/aiohttp/__init__.py @@ -3,7 +3,7 @@ from aiohttp import web # pylint: disable=E0401 -from tortoise import Tortoise +from tortoise import Tortoise, connections from tortoise.log import logger @@ -79,7 +79,7 @@ def register_tortoise( async def init_orm(app): # pylint: disable=W0612 await Tortoise.init(config=config, config_file=config_file, db_url=db_url, modules=modules) - logger.info(f"Tortoise-ORM started, {Tortoise._connections}, {Tortoise.apps}") + logger.info(f"Tortoise-ORM started, {connections._get_storage()}, {Tortoise.apps}") if generate_schemas: logger.info("Tortoise-ORM generating schema") await Tortoise.generate_schemas() diff --git a/tortoise/contrib/blacksheep/__init__.py b/tortoise/contrib/blacksheep/__init__.py index bcde53cb7..32496eb2b 100644 --- a/tortoise/contrib/blacksheep/__init__.py +++ b/tortoise/contrib/blacksheep/__init__.py @@ -5,7 +5,7 @@ from blacksheep.server import Application from blacksheep.server.responses import json -from tortoise import Tortoise +from tortoise import Tortoise, connections from tortoise.exceptions import DoesNotExist, IntegrityError from tortoise.log import logger @@ -87,7 +87,7 @@ def register_tortoise( @app.on_start async def init_orm(context) -> None: # pylint: disable=W0612 await Tortoise.init(config=config, config_file=config_file, db_url=db_url, modules=modules) - logger.info("Tortoise-ORM started, %s, %s", Tortoise._connections, Tortoise.apps) + logger.info("Tortoise-ORM started, %s, %s", connections._get_storage(), Tortoise.apps) if generate_schemas: logger.info("Tortoise-ORM generating schema") await Tortoise.generate_schemas() diff --git a/tortoise/contrib/fastapi/__init__.py b/tortoise/contrib/fastapi/__init__.py index 45a5f3ea9..83451f0a2 100644 --- a/tortoise/contrib/fastapi/__init__.py +++ b/tortoise/contrib/fastapi/__init__.py @@ -5,7 +5,7 @@ from fastapi.responses import JSONResponse from pydantic import BaseModel # pylint: disable=E0611 -from tortoise import Tortoise +from tortoise import Tortoise, connections from tortoise.exceptions import DoesNotExist, IntegrityError from tortoise.log import logger @@ -91,7 +91,7 @@ def register_tortoise( @app.on_event("startup") async def init_orm() -> None: # pylint: disable=W0612 await Tortoise.init(config=config, config_file=config_file, db_url=db_url, modules=modules) - logger.info("Tortoise-ORM started, %s, %s", Tortoise._connections, Tortoise.apps) + logger.info("Tortoise-ORM started, %s, %s", connections._get_storage(), Tortoise.apps) if generate_schemas: logger.info("Tortoise-ORM generating schema") await Tortoise.generate_schemas() diff --git a/tortoise/contrib/quart/__init__.py b/tortoise/contrib/quart/__init__.py index e13c25053..14ea890a1 100644 --- a/tortoise/contrib/quart/__init__.py +++ b/tortoise/contrib/quart/__init__.py @@ -5,7 +5,7 @@ from quart import Quart # pylint: disable=E0401 -from tortoise import Tortoise +from tortoise import Tortoise, connections from tortoise.log import logger @@ -84,7 +84,7 @@ def register_tortoise( @app.before_serving async def init_orm() -> None: # pylint: disable=W0612 await Tortoise.init(config=config, config_file=config_file, db_url=db_url, modules=modules) - logger.info("Tortoise-ORM started, %s, %s", Tortoise._connections, Tortoise.apps) + logger.info("Tortoise-ORM started, %s, %s", connections._get_storage(), Tortoise.apps) if _generate_schemas: logger.info("Tortoise-ORM generating schema") await Tortoise.generate_schemas() diff --git a/tortoise/contrib/sanic/__init__.py b/tortoise/contrib/sanic/__init__.py index c1fd8e526..5cb2f7c15 100644 --- a/tortoise/contrib/sanic/__init__.py +++ b/tortoise/contrib/sanic/__init__.py @@ -3,7 +3,7 @@ from sanic import Sanic # pylint: disable=E0401 -from tortoise import Tortoise +from tortoise import Tortoise, connections from tortoise.log import logger @@ -80,7 +80,7 @@ def register_tortoise( @app.listener("before_server_start") # type:ignore async def init_orm(app, loop): # pylint: disable=W0612 await Tortoise.init(config=config, config_file=config_file, db_url=db_url, modules=modules) - logger.info("Tortoise-ORM started, %s, %s", Tortoise._connections, Tortoise.apps) + logger.info("Tortoise-ORM started, %s, %s", connections._get_storage(), Tortoise.apps) if generate_schemas: logger.info("Tortoise-ORM generating schema") await Tortoise.generate_schemas() diff --git a/tortoise/contrib/starlette/__init__.py b/tortoise/contrib/starlette/__init__.py index d5952155e..38c209cdf 100644 --- a/tortoise/contrib/starlette/__init__.py +++ b/tortoise/contrib/starlette/__init__.py @@ -3,7 +3,7 @@ from starlette.applications import Starlette # pylint: disable=E0401 -from tortoise import Tortoise +from tortoise import Tortoise, connections from tortoise.log import logger @@ -80,7 +80,7 @@ def register_tortoise( @app.on_event("startup") async def init_orm() -> None: # pylint: disable=W0612 await Tortoise.init(config=config, config_file=config_file, db_url=db_url, modules=modules) - logger.info("Tortoise-ORM started, %s, %s", Tortoise._connections, Tortoise.apps) + logger.info("Tortoise-ORM started, %s, %s", connections._get_storage(), Tortoise.apps) if generate_schemas: logger.info("Tortoise-ORM generating schema") await Tortoise.generate_schemas() From a7f47ba7e62941507720b9073dcea0721e51b66d Mon Sep 17 00:00:00 2001 From: Aditya N Date: Thu, 13 Jan 2022 22:40:13 +0530 Subject: [PATCH 12/25] - Fixed style and lint errors --- examples/manual_sql.py | 2 +- examples/schema_create.py | 2 +- examples/two_databases.py | 2 +- tests/backends/test_capabilities.py | 2 +- tests/backends/test_connection_params.py | 4 ++-- tests/backends/test_reconnect.py | 2 +- tests/contrib/test_functions.py | 2 +- tests/test_case_when.py | 2 +- tests/test_manual_sql.py | 2 +- 9 files changed, 10 insertions(+), 10 deletions(-) diff --git a/examples/manual_sql.py b/examples/manual_sql.py index eab58d134..b3d46e036 100644 --- a/examples/manual_sql.py +++ b/examples/manual_sql.py @@ -1,7 +1,7 @@ """ This example demonstrates executing manual SQL queries """ -from tortoise import Tortoise, fields, run_async, connections +from tortoise import Tortoise, connections, fields, run_async from tortoise.models import Model from tortoise.transactions import in_transaction diff --git a/examples/schema_create.py b/examples/schema_create.py index 86f112a86..4dcb47687 100644 --- a/examples/schema_create.py +++ b/examples/schema_create.py @@ -1,7 +1,7 @@ """ This example demonstrates SQL Schema generation for each DB type supported. """ -from tortoise import Tortoise, fields, run_async, connections +from tortoise import Tortoise, connections, fields, run_async from tortoise.models import Model from tortoise.utils import get_schema_sql diff --git a/examples/two_databases.py b/examples/two_databases.py index 633c3a961..96d59d68b 100644 --- a/examples/two_databases.py +++ b/examples/two_databases.py @@ -8,7 +8,7 @@ Key notes of this example is using db_route for Tortoise init and explicitly declaring model apps in class Meta """ -from tortoise import Tortoise, fields, run_async, connections +from tortoise import Tortoise, connections, fields, run_async from tortoise.exceptions import OperationalError from tortoise.models import Model diff --git a/tests/backends/test_capabilities.py b/tests/backends/test_capabilities.py index 54beaaa7a..1a9c1be21 100644 --- a/tests/backends/test_capabilities.py +++ b/tests/backends/test_capabilities.py @@ -1,4 +1,4 @@ -from tortoise import Tortoise, connections +from tortoise import connections from tortoise.contrib import test diff --git a/tests/backends/test_connection_params.py b/tests/backends/test_connection_params.py index f5375e9f3..f713817d2 100644 --- a/tests/backends/test_connection_params.py +++ b/tests/backends/test_connection_params.py @@ -30,7 +30,7 @@ async def test_mysql_connection_params(self): }, } }, - False + False, ) await connections.get("models").create_connection(with_db=True) @@ -66,7 +66,7 @@ async def test_postres_connection_params(self): }, } }, - False + False, ) await connections.get("models").create_connection(with_db=True) diff --git a/tests/backends/test_reconnect.py b/tests/backends/test_reconnect.py index 93766ccb3..b0e07ef6b 100644 --- a/tests/backends/test_reconnect.py +++ b/tests/backends/test_reconnect.py @@ -1,5 +1,5 @@ from tests.testmodels import Tournament -from tortoise import Tortoise, connections +from tortoise import connections from tortoise.contrib import test from tortoise.transactions import in_transaction diff --git a/tests/contrib/test_functions.py b/tests/contrib/test_functions.py index c11fad9cb..6f4e2c1ff 100644 --- a/tests/contrib/test_functions.py +++ b/tests/contrib/test_functions.py @@ -1,5 +1,5 @@ from tests.testmodels import IntFields -from tortoise import Tortoise, connections +from tortoise import connections from tortoise.contrib import test from tortoise.contrib.mysql.functions import Rand from tortoise.contrib.postgres.functions import Random as PostgresRandom diff --git a/tests/test_case_when.py b/tests/test_case_when.py index 59a9a93f3..fa041d7e4 100644 --- a/tests/test_case_when.py +++ b/tests/test_case_when.py @@ -1,5 +1,5 @@ from tests.testmodels import IntFields -from tortoise import Tortoise, connections +from tortoise import connections from tortoise.contrib import test from tortoise.expressions import Case, F, Q, When from tortoise.functions import Coalesce diff --git a/tests/test_manual_sql.py b/tests/test_manual_sql.py index 1b8f92a99..bfbf5cc95 100644 --- a/tests/test_manual_sql.py +++ b/tests/test_manual_sql.py @@ -1,4 +1,4 @@ -from tortoise import Tortoise, connections +from tortoise import connections from tortoise.contrib import test from tortoise.transactions import in_transaction From e8e529bdd02316d91c22e0d7ec6e9b57d36cb914 Mon Sep 17 00:00:00 2001 From: Aditya N Date: Tue, 18 Jan 2022 19:11:00 +0530 Subject: [PATCH 13/25] - Added unit tests for the new connection interface --- tests/test_connection.py | 271 +++++++++++++++++++++++++++++++++++++++ tortoise/connection.py | 12 +- 2 files changed, 277 insertions(+), 6 deletions(-) create mode 100644 tests/test_connection.py diff --git a/tests/test_connection.py b/tests/test_connection.py new file mode 100644 index 000000000..f0b718a6d --- /dev/null +++ b/tests/test_connection.py @@ -0,0 +1,271 @@ +from contextvars import ContextVar +from unittest.mock import AsyncMock, Mock, PropertyMock, call, patch + +from tortoise import BaseDBAsyncClient, ConfigurationError +from tortoise.connection import ConnectionHandler +from tortoise.contrib.test import SimpleTestCase + + +class TestConnections(SimpleTestCase): + def setUp(self) -> None: + self.conn_handler = ConnectionHandler() + + def test_init_constructor(self): + self.assertIsNone(self.conn_handler._db_config) + self.assertFalse(self.conn_handler._create_db) + + @patch("tortoise.connection.ConnectionHandler._init_connections") + async def test_init(self, mocked_init_connections: AsyncMock): + db_config = {"default": {"HOST": "some_host", "PORT": "1234"}} + await self.conn_handler._init(db_config, True) + mocked_init_connections.assert_awaited_once() + self.assertEqual(db_config, self.conn_handler._db_config) + self.assertTrue(self.conn_handler._create_db) + + def test_db_config_present(self): + self.conn_handler._db_config = {"default": {"HOST": "some_host", "PORT": "1234"}} + self.assertEqual(self.conn_handler.db_config, self.conn_handler._db_config) + + def test_db_config_not_present(self): + err_msg = ( + "DB configuration not initialised. Make sure to call " + "Tortoise.init with a valid configuration before attempting " + "to create connections." + ) + with self.assertRaises(ConfigurationError, msg=err_msg): + _ = self.conn_handler.db_config + + @patch("tortoise.connection.ConnectionHandler._conn_storage", spec=ContextVar) + def test_get_storage(self, mocked_conn_storage: Mock): + expected_ret_val = {"default": BaseDBAsyncClient("default")} + mocked_conn_storage.get.return_value = expected_ret_val + ret_val = self.conn_handler._get_storage() + self.assertDictEqual(ret_val, expected_ret_val) + + @patch("tortoise.connection.ConnectionHandler._conn_storage", spec=ContextVar) + def test_set_storage(self, mocked_conn_storage: Mock): + mocked_conn_storage.set.return_value = "blah" + new_storage = {"default": BaseDBAsyncClient("default")} + ret_val = self.conn_handler._set_storage(new_storage) + mocked_conn_storage.set.assert_called_once_with(new_storage) + self.assertEqual(ret_val, mocked_conn_storage.set.return_value) + + @patch("tortoise.connection.ConnectionHandler._get_storage") + @patch("tortoise.connection.copy") + def test_copy_storage(self, mocked_copy: Mock, mocked_get_storage: Mock): + expected_ret_value = {"default": BaseDBAsyncClient("default")} + mocked_get_storage.return_value = expected_ret_value + mocked_copy.return_value = expected_ret_value.copy() + ret_val = self.conn_handler._copy_storage() + mocked_get_storage.assert_called_once() + mocked_copy.assert_called_once_with(mocked_get_storage.return_value) + self.assertDictEqual(ret_val, expected_ret_value) + self.assertNotEqual(id(expected_ret_value), id(ret_val)) + + @patch("tortoise.connection.ConnectionHandler._get_storage") + def test_clear_storage(self, mocked_get_storage: Mock): + self.conn_handler._clear_storage() + mocked_get_storage.assert_called_once() + mocked_get_storage.return_value.clear.assert_called_once() + + @patch("tortoise.connection.importlib.import_module") + def test_discover_client_class_proper_impl(self, mocked_import_module: Mock): + mocked_import_module.return_value = Mock(client_class="some_class") + client_class = self.conn_handler._discover_client_class("blah") + mocked_import_module.assert_called_once_with("blah") + self.assertEqual(client_class, "some_class") + + @patch("tortoise.connection.importlib.import_module") + def test_discover_client_class_improper_impl(self, mocked_import_module: Mock): + del mocked_import_module.return_value.client_class + engine = "some_engine" + with self.assertRaises( + ConfigurationError, msg=f'Backend for engine "{engine}" does not implement db client' + ): + _ = self.conn_handler._discover_client_class(engine) + + @patch("tortoise.connection.ConnectionHandler.db_config", new_callable=PropertyMock) + def test_get_db_info_present(self, mocked_db_config: Mock): + expected_ret_val = {"HOST": "some_host", "PORT": "1234"} + mocked_db_config.return_value = {"default": expected_ret_val} + ret_val = self.conn_handler._get_db_info("default") + self.assertEqual(ret_val, expected_ret_val) + + @patch("tortoise.connection.ConnectionHandler.db_config", new_callable=PropertyMock) + def test_get_db_info_not_present(self, mocked_db_config: Mock): + mocked_db_config.return_value = {"default": {"HOST": "some_host", "PORT": "1234"}} + conn_alias = "blah" + with self.assertRaises( + ConfigurationError, + msg=f"Unable to get db settings for alias '{conn_alias}'. Please " + f"check if the config dict contains this alias and try again", + ): + _ = self.conn_handler._get_db_info(conn_alias) + + @patch("tortoise.connection.ConnectionHandler.db_config", new_callable=PropertyMock) + @patch("tortoise.connection.ConnectionHandler.get") + async def test_init_connections_no_db_create(self, mocked_get: Mock, mocked_db_config: Mock): + conn_1, conn_2 = AsyncMock(spec=BaseDBAsyncClient), AsyncMock(spec=BaseDBAsyncClient) + mocked_get.side_effect = [conn_1, conn_2] + mocked_db_config.return_value = { + "default": {"HOST": "some_host", "PORT": "1234"}, + "other": {"HOST": "some_other_host", "PORT": "1234"}, + } + await self.conn_handler._init_connections() + mocked_db_config.assert_called_once() + mocked_get.assert_has_calls([call("default"), call("other")], any_order=True) + conn_1.db_create.assert_not_awaited() + conn_2.db_create.assert_not_awaited() + + @patch("tortoise.connection.ConnectionHandler.db_config", new_callable=PropertyMock) + @patch("tortoise.connection.ConnectionHandler.get") + async def test_init_connections_db_create(self, mocked_get: Mock, mocked_db_config: Mock): + self.conn_handler._create_db = True + conn_1, conn_2 = AsyncMock(spec=BaseDBAsyncClient), AsyncMock(spec=BaseDBAsyncClient) + mocked_get.side_effect = [conn_1, conn_2] + mocked_db_config.return_value = { + "default": {"HOST": "some_host", "PORT": "1234"}, + "other": {"HOST": "some_other_host", "PORT": "1234"}, + } + await self.conn_handler._init_connections() + mocked_db_config.assert_called_once() + mocked_get.assert_has_calls([call("default"), call("other")], any_order=True) + conn_1.db_create.assert_awaited_once() + conn_2.db_create.assert_awaited_once() + + @patch("tortoise.connection.ConnectionHandler._get_db_info") + @patch("tortoise.connection.expand_db_url") + @patch("tortoise.connection.ConnectionHandler._discover_client_class") + def test_create_connection_db_info_str( + self, + mocked_discover_client_class: Mock, + mocked_expand_db_url: Mock, + mocked_get_db_info: Mock, + ): + alias = "default" + mocked_get_db_info.return_value = "some_db_url" + mocked_expand_db_url.return_value = { + "engine": "some_engine", + "credentials": {"cred_key": "some_val"}, + } + expected_client_class = Mock(return_value="some_connection") + mocked_discover_client_class.return_value = expected_client_class + expected_db_params = {"cred_key": "some_val", "connection_name": alias} + + ret_val = self.conn_handler._create_connection(alias) + + mocked_get_db_info.assert_called_once_with(alias) + mocked_expand_db_url.assert_called_once_with("some_db_url") + mocked_discover_client_class.assert_called_once_with("some_engine") + expected_client_class.assert_called_once_with(**expected_db_params) + self.assertEqual(ret_val, "some_connection") + + @patch("tortoise.connection.ConnectionHandler._get_db_info") + @patch("tortoise.connection.expand_db_url") + @patch("tortoise.connection.ConnectionHandler._discover_client_class") + def test_create_connection_db_info_not_str( + self, + mocked_discover_client_class: Mock, + mocked_expand_db_url: Mock, + mocked_get_db_info: Mock, + ): + alias = "default" + mocked_get_db_info.return_value = { + "engine": "some_engine", + "credentials": {"cred_key": "some_val"}, + } + expected_client_class = Mock(return_value="some_connection") + mocked_discover_client_class.return_value = expected_client_class + expected_db_params = {"cred_key": "some_val", "connection_name": alias} + + ret_val = self.conn_handler._create_connection(alias) + + mocked_get_db_info.assert_called_once_with(alias) + mocked_expand_db_url.assert_not_called() + mocked_discover_client_class.assert_called_once_with("some_engine") + expected_client_class.assert_called_once_with(**expected_db_params) + self.assertEqual(ret_val, "some_connection") + + @patch("tortoise.connection.ConnectionHandler._get_storage") + @patch("tortoise.connection.ConnectionHandler._create_connection") + def test_get_alias_present(self, mocked_create_connection: Mock, mocked_get_storage: Mock): + mocked_get_storage.return_value = {"default": "some_connection"} + ret_val = self.conn_handler.get("default") + mocked_get_storage.assert_called_once() + mocked_create_connection.assert_not_called() + self.assertEqual(ret_val, "some_connection") + + @patch("tortoise.connection.ConnectionHandler._get_storage") + @patch("tortoise.connection.ConnectionHandler._create_connection") + def test_get_alias_not_present(self, mocked_create_connection: Mock, mocked_get_storage: Mock): + mocked_get_storage.return_value = {"default": "some_connection"} + expected_final_dict = {**mocked_get_storage.return_value, "other": "some_other_connection"} + mocked_create_connection.return_value = "some_other_connection" + ret_val = self.conn_handler.get("other") + mocked_get_storage.assert_called_once() + mocked_create_connection.assert_called_once_with("other") + self.assertEqual(ret_val, "some_other_connection") + self.assertDictEqual(mocked_get_storage.return_value, expected_final_dict) + + @patch("tortoise.connection.ConnectionHandler._conn_storage", spec=ContextVar) + @patch("tortoise.connection.ConnectionHandler._copy_storage") + def test_set(self, mocked_copy_storage: Mock, mocked_conn_storage: Mock): + mocked_copy_storage.return_value = {} + expected_storage = {"default": "some_conn"} + mocked_conn_storage.set.return_value = "some_token" + ret_val = self.conn_handler.set("default", "some_conn") + mocked_copy_storage.assert_called_once() + self.assertEqual(ret_val, mocked_conn_storage.set.return_value) + self.assertDictEqual(expected_storage, mocked_copy_storage.return_value) + + @patch("tortoise.connection.ConnectionHandler._get_storage") + def test_discard(self, mocked_get_storage: Mock): + mocked_get_storage.return_value = {"default": "some_conn"} + ret_val = self.conn_handler.discard("default") + self.assertEqual(ret_val, "some_conn") + self.assertDictEqual({}, mocked_get_storage.return_value) + + @patch("tortoise.connection.ConnectionHandler._conn_storage", spec=ContextVar) + @patch("tortoise.connection.ConnectionHandler._get_storage") + def test_reset(self, mocked_get_storage: Mock, mocked_conn_storage: Mock): + first_config = {"other": "some_other_conn", "default": "diff_conn"} + second_config = {"default": "some_conn"} + mocked_get_storage.side_effect = [first_config, second_config] + final_storage = {"default": "some_conn", "other": "some_other_conn"} + self.conn_handler.reset("some_token") # type: ignore + mocked_get_storage.assert_has_calls([call(), call()]) + mocked_conn_storage.reset.assert_called_once_with("some_token") + self.assertDictEqual(final_storage, second_config) + + @patch("tortoise.connection.ConnectionHandler._get_storage") + def test_all(self, mocked_get_storage: Mock): + mocked_get_storage.return_value = {"default": "some_conn", "other": "some_other_conn"} + expected_result = ["some_conn", "some_other_conn"] + ret_val = self.conn_handler.all() + self.assertEqual(ret_val, expected_result) + + @patch("tortoise.connection.ConnectionHandler._get_storage") + @patch("tortoise.connection.ConnectionHandler.discard") + async def test_close_all_with_discard(self, mocked_discard: Mock, mocked_get_storage: Mock): + storage = { + "default": AsyncMock(spec=BaseDBAsyncClient), + "other": AsyncMock(spec=BaseDBAsyncClient), + } + mocked_get_storage.return_value = storage + await self.conn_handler.close_all(discard=True) + for mock_obj in storage.values(): + mock_obj.close.assert_awaited_once() + mocked_discard.assert_has_calls([call("default"), call("other")], any_order=True) + + @patch("tortoise.connection.ConnectionHandler._get_storage") + @patch("tortoise.connection.ConnectionHandler.discard") + async def test_close_all_without_discard(self, mocked_discard: Mock, mocked_get_storage: Mock): + storage = { + "default": AsyncMock(spec=BaseDBAsyncClient), + "other": AsyncMock(spec=BaseDBAsyncClient), + } + mocked_get_storage.return_value = storage + await self.conn_handler.close_all() + for mock_obj in storage.values(): + mock_obj.close.assert_awaited_once() + mocked_discard.assert_not_called() diff --git a/tortoise/connection.py b/tortoise/connection.py index 972f46391..77899f054 100644 --- a/tortoise/connection.py +++ b/tortoise/connection.py @@ -1,8 +1,8 @@ import asyncio import contextvars -import copy import importlib from contextvars import ContextVar +from copy import copy from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union from tortoise.backends.base.config_generator import expand_db_url @@ -42,8 +42,12 @@ def db_config(self) -> "DBConfigType": def _get_storage(self) -> Dict[str, "BaseDBAsyncClient"]: return self._conn_storage.get() + def _set_storage(self, new_storage: Dict[str, "BaseDBAsyncClient"]) -> contextvars.Token: + # Should be used only for testing purposes. + return self._conn_storage.set(new_storage) + def _copy_storage(self) -> Dict[str, "BaseDBAsyncClient"]: - return copy.copy(self._get_storage()) + return copy(self._get_storage()) def _clear_storage(self) -> None: self._get_storage().clear() @@ -91,10 +95,6 @@ def get(self, conn_alias: str) -> "BaseDBAsyncClient": storage[conn_alias] = connection return connection - def _set_storage(self, new_storage: Dict[str, "BaseDBAsyncClient"]) -> contextvars.Token: - # Should be used only for testing purposes. - return self._conn_storage.set(new_storage) - def set(self, conn_alias: str, conn) -> contextvars.Token: storage_copy = self._copy_storage() storage_copy[conn_alias] = conn From ab7051eda3518c8b275d33eb33f5f72ff3862d58 Mon Sep 17 00:00:00 2001 From: Aditya N Date: Wed, 19 Jan 2022 00:40:23 +0530 Subject: [PATCH 14/25] - Added docs for the new connection interface --- docs/connections.rst | 58 +++++++++++++++++++++++++++++++++++++ docs/databases.rst | 1 + docs/reference.rst | 1 + tests/test_connection.py | 2 +- tortoise/__init__.py | 15 ++-------- tortoise/connection.py | 62 ++++++++++++++++++++++++++++++++++++++-- 6 files changed, 122 insertions(+), 17 deletions(-) create mode 100644 docs/connections.rst diff --git a/docs/connections.rst b/docs/connections.rst new file mode 100644 index 000000000..4e9448fdc --- /dev/null +++ b/docs/connections.rst @@ -0,0 +1,58 @@ +.. _connections: + +=========== +Connections +=========== + +This document describes how to access the underlying connection object (:ref:`BaseDBAsyncClient`) for the aliases defined +as part of the DB config passed to the :meth:`Tortoise.init` call. + +Below is a simple code snippet which shows how the interface can be accessed: + +.. code-block:: python3 + + # connections is a singleton instance of the ConnectionHandler class and serves as the + # entrypoint to access all connection management APIs. + from tortoise import connections + + + # Assume that this is the Tortoise configuration used + await Tortoise.init( + { + "connections": { + "default": { + "engine": "tortoise.backends.sqlite", + "credentials": {"file_path": "example.sqlite3"}, + } + }, + "apps": { + "events": {"models": ["__main__"], "default_connection": "default"} + }, + } + ) + + conn: BaseDBAsyncClient = connections.get("default") + try: + await conn.execute_query('SELECT * FROM "event"') + except OperationalError: + print("Expected it to fail") + +.. important:: + The :ref:`tortoise.connection.ConnectionHandler` class has been implemented with the singleton + pattern in mind and so when the ORM initializes, a singleton instance of this class + ``tortoise.connection.connections`` is created automatically and lives in memory up until the lifetime of the app. + Any attempt to modify or override its behaviour at runtime is risky and not recommended. + + +Please refer to :ref:`this example` for a detailed demonstration of how developers can use +this in practice. + + +API Reference +=========== + +.. _connection_handler: + +.. automodule:: tortoise.connection + :members: + :undoc-members: \ No newline at end of file diff --git a/docs/databases.rst b/docs/databases.rst index 4eac300af..65b78375e 100644 --- a/docs/databases.rst +++ b/docs/databases.rst @@ -208,6 +208,7 @@ handle complex objects. } ) +.. _base_db_client: Base DB client ============== diff --git a/docs/reference.rst b/docs/reference.rst index 67721f10b..fcfad2186 100644 --- a/docs/reference.rst +++ b/docs/reference.rst @@ -16,6 +16,7 @@ Reference functions expressions transactions + connections exceptions signals migration diff --git a/tests/test_connection.py b/tests/test_connection.py index f0b718a6d..fd1039baa 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -213,7 +213,7 @@ def test_set(self, mocked_copy_storage: Mock, mocked_conn_storage: Mock): mocked_copy_storage.return_value = {} expected_storage = {"default": "some_conn"} mocked_conn_storage.set.return_value = "some_token" - ret_val = self.conn_handler.set("default", "some_conn") + ret_val = self.conn_handler.set("default", "some_conn") # type: ignore mocked_copy_storage.assert_called_once() self.assertEqual(ret_val, mocked_conn_storage.set.return_value) self.assertDictEqual(expected_storage, mocked_copy_storage.return_value) diff --git a/tortoise/__init__.py b/tortoise/__init__.py index 9fedec309..2ae6ae33f 100644 --- a/tortoise/__init__.py +++ b/tortoise/__init__.py @@ -37,11 +37,11 @@ def get_connection(cls, connection_name: str) -> BaseDBAsyncClient: """ Returns the connection by name. - :raises KeyError: If connection name does not exist. + :raises ConfigurationError: If connection name does not exist. .. warning:: This is deprecated and will be removed in a future release. Please use - :meth:`tortoise.connection.connections.get` instead. + :meth:`connections.get` instead. """ return connections.get(connection_name) @@ -339,17 +339,6 @@ def split_reference(reference: str) -> Tuple[str, str]: model._meta.filters.update(get_m2m_filters(field, m2m_object)) related_model._meta.add_field(backward_relation_name, m2m_relation) - @classmethod - def _discover_client_class(cls, engine: str) -> Type[BaseDBAsyncClient]: - # Let exception bubble up for transparency - engine_module = importlib.import_module(engine) - - try: - client_class = engine_module.client_class - except AttributeError: - raise ConfigurationError(f'Backend for engine "{engine}" does not implement db client') - return client_class - @classmethod def _discover_models( cls, models_path: Union[ModuleType, str], app_label: str diff --git a/tortoise/connection.py b/tortoise/connection.py index 77899f054..73b49f883 100644 --- a/tortoise/connection.py +++ b/tortoise/connection.py @@ -31,6 +31,14 @@ async def _init(self, db_config: "DBConfigType", create_db: bool): @property def db_config(self) -> "DBConfigType": + """ + Returns the DB config with which the + :meth:`Tortoise.init` method was called. + + :raises ConfigurationError: + If this property is accessed before calling the + :meth:`Tortoise.init` method. + """ if self._db_config is None: raise ConfigurationError( "DB configuration not initialised. Make sure to call " @@ -87,6 +95,15 @@ def _create_connection(self, conn_alias: str) -> "BaseDBAsyncClient": return connection def get(self, conn_alias: str) -> "BaseDBAsyncClient": + """ + Used for accessing the low-level connection object + (:class:`BaseDBAsyncClient`) for the + given alias. + + :param conn_alias: The alias for which the connection has to be fetched + + :raises ConfigurationError: If the connection alias does not exist. + """ storage: Dict[str, "BaseDBAsyncClient"] = self._get_storage() try: return storage[conn_alias] @@ -95,15 +112,46 @@ def get(self, conn_alias: str) -> "BaseDBAsyncClient": storage[conn_alias] = connection return connection - def set(self, conn_alias: str, conn) -> contextvars.Token: + def set(self, conn_alias: str, conn_obj: "BaseDBAsyncClient") -> contextvars.Token: + """ + Sets the given alias to the provided connection object. + + :param conn_alias: The alias to set the connection for. + :param conn_obj: The connection object that needs to be set for this alias. + + .. note:: + This method copies the storage from the `current context`, updates the + ``conn_alias`` with the provided ``conn_obj`` and sets the updated storage + in a `new context` and therefore returns a ``contextvars.Token`` in order to restore + the original context storage. + """ storage_copy = self._copy_storage() - storage_copy[conn_alias] = conn + storage_copy[conn_alias] = conn_obj return self._conn_storage.set(storage_copy) def discard(self, conn_alias: str) -> Optional["BaseDBAsyncClient"]: + """ + Discards the given alias from the storage in the `current context`. + + :param conn_alias: The alias for which the connection object should be discarded. + + .. important:: + Make sure to have called ``conn.close()`` for the provided alias before calling + this method else there would be a connection leak (dangling connection). + """ return self._get_storage().pop(conn_alias, None) - def reset(self, token: contextvars.Token): + def reset(self, token: contextvars.Token) -> None: + """ + Resets the storage state to the `context` associated with the provided token. After + resetting storage state, any additional `connections` created in the `old context` are + copied into the `current context`. + + :param token: + The token corresponding to the `context` to which the storage state has to + be reset. Typically, this token is obtained by calling the + :meth:`set` method of this class. + """ current_storage = self._get_storage() self._conn_storage.reset(token) prev_storage = self._get_storage() @@ -112,11 +160,19 @@ def reset(self, token: contextvars.Token): prev_storage[alias] = conn def all(self) -> List["BaseDBAsyncClient"]: + """Returns a list of connection objects from the storage in the `current context`.""" # Returning a list here so as to avoid accidental # mutation of the underlying storage dict return list(self._get_storage().values()) async def close_all(self, discard: bool = False) -> None: + """ + Closes all connections in the storage in the `current context`. + + :param discard: + If ``True``, the connection object is discarded from the storage + after being closed. + """ tasks = [conn.close() for conn in self._get_storage().values()] await asyncio.gather(*tasks) if discard: From d99aeda6acce106ded2bf991b8dbaf0faa2180f0 Mon Sep 17 00:00:00 2001 From: Aditya N Date: Wed, 19 Jan 2022 00:45:50 +0530 Subject: [PATCH 15/25] - Fixed codacy check --- tortoise/connection.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tortoise/connection.py b/tortoise/connection.py index 73b49f883..0fcda0238 100644 --- a/tortoise/connection.py +++ b/tortoise/connection.py @@ -32,8 +32,8 @@ async def _init(self, db_config: "DBConfigType", create_db: bool): @property def db_config(self) -> "DBConfigType": """ - Returns the DB config with which the - :meth:`Tortoise.init` method was called. + Returns the DB config with which the :meth:`Tortoise.init` + method was called. :raises ConfigurationError: If this property is accessed before calling the From 9adabf78a18fec22cf2e79083b909153ed592215 Mon Sep 17 00:00:00 2001 From: Aditya N Date: Wed, 19 Jan 2022 00:50:59 +0530 Subject: [PATCH 16/25] - Fixed codacy check --- tortoise/connection.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tortoise/connection.py b/tortoise/connection.py index 0fcda0238..2d9daca4d 100644 --- a/tortoise/connection.py +++ b/tortoise/connection.py @@ -31,8 +31,7 @@ async def _init(self, db_config: "DBConfigType", create_db: bool): @property def db_config(self) -> "DBConfigType": - """ - Returns the DB config with which the :meth:`Tortoise.init` + """Returns the DB config with which the :meth:`Tortoise.init` method was called. :raises ConfigurationError: From 1edd8135644676edde5b008154e7a0f9a5e4d695 Mon Sep 17 00:00:00 2001 From: Aditya N Date: Wed, 19 Jan 2022 01:12:43 +0530 Subject: [PATCH 17/25] - Fixed codacy check --- tortoise/connection.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tortoise/connection.py b/tortoise/connection.py index 2d9daca4d..4e6127a4e 100644 --- a/tortoise/connection.py +++ b/tortoise/connection.py @@ -31,8 +31,11 @@ async def _init(self, db_config: "DBConfigType", create_db: bool): @property def db_config(self) -> "DBConfigType": - """Returns the DB config with which the :meth:`Tortoise.init` - method was called. + """ + Return the DB config. + + This is the same config passed to the + :meth:`Tortoise.init` method while initialization. :raises ConfigurationError: If this property is accessed before calling the @@ -95,6 +98,8 @@ def _create_connection(self, conn_alias: str) -> "BaseDBAsyncClient": def get(self, conn_alias: str) -> "BaseDBAsyncClient": """ + Return the connection object for the given alias, creating it if needed. + Used for accessing the low-level connection object (:class:`BaseDBAsyncClient`) for the given alias. @@ -142,6 +147,8 @@ def discard(self, conn_alias: str) -> Optional["BaseDBAsyncClient"]: def reset(self, token: contextvars.Token) -> None: """ + Reset the underlying storage to the previous context state. + Resets the storage state to the `context` associated with the provided token. After resetting storage state, any additional `connections` created in the `old context` are copied into the `current context`. From d7e70485eb276061aa59225984f4552fd4b33583 Mon Sep 17 00:00:00 2001 From: Aditya N Date: Wed, 19 Jan 2022 01:17:58 +0530 Subject: [PATCH 18/25] - Fixed codacy check --- tortoise/connection.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/tortoise/connection.py b/tortoise/connection.py index 4e6127a4e..e8fb8f846 100644 --- a/tortoise/connection.py +++ b/tortoise/connection.py @@ -31,8 +31,7 @@ async def _init(self, db_config: "DBConfigType", create_db: bool): @property def db_config(self) -> "DBConfigType": - """ - Return the DB config. + """Return the DB config. This is the same config passed to the :meth:`Tortoise.init` method while initialization. @@ -97,8 +96,7 @@ def _create_connection(self, conn_alias: str) -> "BaseDBAsyncClient": return connection def get(self, conn_alias: str) -> "BaseDBAsyncClient": - """ - Return the connection object for the given alias, creating it if needed. + """Return the connection object for the given alias, creating it if needed. Used for accessing the low-level connection object (:class:`BaseDBAsyncClient`) for the @@ -117,8 +115,7 @@ def get(self, conn_alias: str) -> "BaseDBAsyncClient": return connection def set(self, conn_alias: str, conn_obj: "BaseDBAsyncClient") -> contextvars.Token: - """ - Sets the given alias to the provided connection object. + """Sets the given alias to the provided connection object. :param conn_alias: The alias to set the connection for. :param conn_obj: The connection object that needs to be set for this alias. @@ -134,8 +131,7 @@ def set(self, conn_alias: str, conn_obj: "BaseDBAsyncClient") -> contextvars.Tok return self._conn_storage.set(storage_copy) def discard(self, conn_alias: str) -> Optional["BaseDBAsyncClient"]: - """ - Discards the given alias from the storage in the `current context`. + """Discards the given alias from the storage in the `current context`. :param conn_alias: The alias for which the connection object should be discarded. @@ -146,8 +142,7 @@ def discard(self, conn_alias: str) -> Optional["BaseDBAsyncClient"]: return self._get_storage().pop(conn_alias, None) def reset(self, token: contextvars.Token) -> None: - """ - Reset the underlying storage to the previous context state. + """Reset the underlying storage to the previous context state. Resets the storage state to the `context` associated with the provided token. After resetting storage state, any additional `connections` created in the `old context` are @@ -172,8 +167,7 @@ def all(self) -> List["BaseDBAsyncClient"]: return list(self._get_storage().values()) async def close_all(self, discard: bool = False) -> None: - """ - Closes all connections in the storage in the `current context`. + """Closes all connections in the storage in the `current context`. :param discard: If ``True``, the connection object is discarded from the storage From c04d1ebcc7edfadabe8fa4109d10c32b58a584cf Mon Sep 17 00:00:00 2001 From: Aditya N Date: Wed, 19 Jan 2022 01:21:56 +0530 Subject: [PATCH 19/25] - Fixed codacy check --- tortoise/connection.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/tortoise/connection.py b/tortoise/connection.py index e8fb8f846..4e6127a4e 100644 --- a/tortoise/connection.py +++ b/tortoise/connection.py @@ -31,7 +31,8 @@ async def _init(self, db_config: "DBConfigType", create_db: bool): @property def db_config(self) -> "DBConfigType": - """Return the DB config. + """ + Return the DB config. This is the same config passed to the :meth:`Tortoise.init` method while initialization. @@ -96,7 +97,8 @@ def _create_connection(self, conn_alias: str) -> "BaseDBAsyncClient": return connection def get(self, conn_alias: str) -> "BaseDBAsyncClient": - """Return the connection object for the given alias, creating it if needed. + """ + Return the connection object for the given alias, creating it if needed. Used for accessing the low-level connection object (:class:`BaseDBAsyncClient`) for the @@ -115,7 +117,8 @@ def get(self, conn_alias: str) -> "BaseDBAsyncClient": return connection def set(self, conn_alias: str, conn_obj: "BaseDBAsyncClient") -> contextvars.Token: - """Sets the given alias to the provided connection object. + """ + Sets the given alias to the provided connection object. :param conn_alias: The alias to set the connection for. :param conn_obj: The connection object that needs to be set for this alias. @@ -131,7 +134,8 @@ def set(self, conn_alias: str, conn_obj: "BaseDBAsyncClient") -> contextvars.Tok return self._conn_storage.set(storage_copy) def discard(self, conn_alias: str) -> Optional["BaseDBAsyncClient"]: - """Discards the given alias from the storage in the `current context`. + """ + Discards the given alias from the storage in the `current context`. :param conn_alias: The alias for which the connection object should be discarded. @@ -142,7 +146,8 @@ def discard(self, conn_alias: str) -> Optional["BaseDBAsyncClient"]: return self._get_storage().pop(conn_alias, None) def reset(self, token: contextvars.Token) -> None: - """Reset the underlying storage to the previous context state. + """ + Reset the underlying storage to the previous context state. Resets the storage state to the `context` associated with the provided token. After resetting storage state, any additional `connections` created in the `old context` are @@ -167,7 +172,8 @@ def all(self) -> List["BaseDBAsyncClient"]: return list(self._get_storage().values()) async def close_all(self, discard: bool = False) -> None: - """Closes all connections in the storage in the `current context`. + """ + Closes all connections in the storage in the `current context`. :param discard: If ``True``, the connection object is discarded from the storage From 32f87e4c5306ce5ad7c5776eb422bc65dfca8b88 Mon Sep 17 00:00:00 2001 From: Aditya N Date: Wed, 19 Jan 2022 14:35:47 +0530 Subject: [PATCH 20/25] - Refactored Tortoise.close_connections to use the connections.close_all method - Updated relevant tests and docs --- tests/backends/test_postgres.py | 2 +- tests/test_connection.py | 55 ++++++++++++++----------- tortoise/__init__.py | 8 +++- tortoise/connection.py | 19 +++++---- tortoise/contrib/aiohttp/__init__.py | 2 +- tortoise/contrib/blacksheep/__init__.py | 2 +- tortoise/contrib/fastapi/__init__.py | 2 +- tortoise/contrib/quart/__init__.py | 4 +- tortoise/contrib/sanic/__init__.py | 2 +- tortoise/contrib/starlette/__init__.py | 2 +- 10 files changed, 57 insertions(+), 41 deletions(-) diff --git a/tests/backends/test_postgres.py b/tests/backends/test_postgres.py index 19fa2543b..dc18343d7 100644 --- a/tests/backends/test_postgres.py +++ b/tests/backends/test_postgres.py @@ -37,7 +37,7 @@ async def test_schema(self): await Tortoise.generate_schemas() tournament = await Tournament.create(name="Test") - await Tortoise.close_connections() + await connections.close_all() del self.db_config["connections"]["models"]["credentials"]["schema"] await Tortoise.init(self.db_config) diff --git a/tests/test_connection.py b/tests/test_connection.py index fd1039baa..7f6841330 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -237,35 +237,44 @@ def test_reset(self, mocked_get_storage: Mock, mocked_conn_storage: Mock): mocked_conn_storage.reset.assert_called_once_with("some_token") self.assertDictEqual(final_storage, second_config) - @patch("tortoise.connection.ConnectionHandler._get_storage") - def test_all(self, mocked_get_storage: Mock): - mocked_get_storage.return_value = {"default": "some_conn", "other": "some_other_conn"} + @patch("tortoise.connection.ConnectionHandler.db_config", new_callable=PropertyMock) + @patch("tortoise.connection.ConnectionHandler.get") + def test_all(self, mocked_get: Mock, mocked_db_config: Mock): + db_config = {"default": "some_conn", "other": "some_other_conn"} + + def side_effect_callable(alias): + return db_config[alias] + + mocked_get.side_effect = side_effect_callable + mocked_db_config.return_value = db_config expected_result = ["some_conn", "some_other_conn"] ret_val = self.conn_handler.all() + mocked_db_config.assert_called_once() + mocked_get.assert_has_calls([call("default"), call("other")], any_order=True) self.assertEqual(ret_val, expected_result) - @patch("tortoise.connection.ConnectionHandler._get_storage") + @patch("tortoise.connection.ConnectionHandler.all") @patch("tortoise.connection.ConnectionHandler.discard") - async def test_close_all_with_discard(self, mocked_discard: Mock, mocked_get_storage: Mock): - storage = { - "default": AsyncMock(spec=BaseDBAsyncClient), - "other": AsyncMock(spec=BaseDBAsyncClient), - } - mocked_get_storage.return_value = storage - await self.conn_handler.close_all(discard=True) - for mock_obj in storage.values(): + @patch("tortoise.connection.ConnectionHandler.db_config", new_callable=PropertyMock) + async def test_close_all_with_discard( + self, mocked_db_config: Mock, mocked_discard: Mock, mocked_all: Mock + ): + all_conn = [AsyncMock(spec=BaseDBAsyncClient), AsyncMock(spec=BaseDBAsyncClient)] + db_config = {"default": "some_config", "other": "some_other_config"} + mocked_all.return_value = all_conn + mocked_db_config.return_value = db_config + await self.conn_handler.close_all() + mocked_all.assert_called_once() + mocked_db_config.assert_called_once() + for mock_obj in all_conn: mock_obj.close.assert_awaited_once() mocked_discard.assert_has_calls([call("default"), call("other")], any_order=True) - @patch("tortoise.connection.ConnectionHandler._get_storage") - @patch("tortoise.connection.ConnectionHandler.discard") - async def test_close_all_without_discard(self, mocked_discard: Mock, mocked_get_storage: Mock): - storage = { - "default": AsyncMock(spec=BaseDBAsyncClient), - "other": AsyncMock(spec=BaseDBAsyncClient), - } - mocked_get_storage.return_value = storage - await self.conn_handler.close_all() - for mock_obj in storage.values(): + @patch("tortoise.connection.ConnectionHandler.all") + async def test_close_all_without_discard(self, mocked_all: Mock): + all_conn = [AsyncMock(spec=BaseDBAsyncClient), AsyncMock(spec=BaseDBAsyncClient)] + mocked_all.return_value = all_conn + await self.conn_handler.close_all(discard=False) + mocked_all.assert_called_once() + for mock_obj in all_conn: mock_obj.close.assert_awaited_once() - mocked_discard.assert_not_called() diff --git a/tortoise/__init__.py b/tortoise/__init__.py index 2ae6ae33f..c7a5c424c 100644 --- a/tortoise/__init__.py +++ b/tortoise/__init__.py @@ -605,8 +605,12 @@ async def close_connections(cls) -> None: It is required for this to be called on exit, else your event loop may never complete as it is waiting for the connections to die. + + .. warning:: + This is deprecated and will be removed in a future release. Please use + :meth:`connections.close_all` instead. """ - await connections.close_all(discard=True) + await connections.close_all() logger.info("Tortoise-ORM shutdown") @classmethod @@ -645,7 +649,7 @@ async def _drop_databases(cls) -> None: raise ConfigurationError("You have to call .init() first before deleting schemas") # this closes any existing connections/pool if any and clears # the storage - await connections.close_all() + await connections.close_all(discard=False) for conn in connections.all(): await conn.db_delete() connections.discard(conn.connection_name) diff --git a/tortoise/connection.py b/tortoise/connection.py index 4e6127a4e..696ec9118 100644 --- a/tortoise/connection.py +++ b/tortoise/connection.py @@ -167,22 +167,25 @@ def reset(self, token: contextvars.Token) -> None: def all(self) -> List["BaseDBAsyncClient"]: """Returns a list of connection objects from the storage in the `current context`.""" - # Returning a list here so as to avoid accidental - # mutation of the underlying storage dict - return list(self._get_storage().values()) + # The reason this method iterates over db_config and not over `storage` directly is + # because: assume that someone calls `discard` with a certain alias, and calls this + # method subsequently. The alias which just got discarded from the storage would not + # appear in the returned list though it exists as part of the `db_config`. + return [self.get(alias) for alias in self.db_config] - async def close_all(self, discard: bool = False) -> None: + async def close_all(self, discard: bool = True) -> None: """ Closes all connections in the storage in the `current context`. + All closed connections will be removed from the storage by default. + :param discard: - If ``True``, the connection object is discarded from the storage - after being closed. + If ``False``, the connection object is closed but `retained` in the storage. """ - tasks = [conn.close() for conn in self._get_storage().values()] + tasks = [conn.close() for conn in self.all()] await asyncio.gather(*tasks) if discard: - for alias in tuple(self._get_storage()): + for alias in self.db_config: self.discard(alias) diff --git a/tortoise/contrib/aiohttp/__init__.py b/tortoise/contrib/aiohttp/__init__.py index 4781ec91b..c2a171221 100644 --- a/tortoise/contrib/aiohttp/__init__.py +++ b/tortoise/contrib/aiohttp/__init__.py @@ -85,7 +85,7 @@ async def init_orm(app): # pylint: disable=W0612 await Tortoise.generate_schemas() async def close_orm(app): # pylint: disable=W0612 - await Tortoise.close_connections() + await connections.close_all() logger.info("Tortoise-ORM shutdown") app.on_startup.append(init_orm) diff --git a/tortoise/contrib/blacksheep/__init__.py b/tortoise/contrib/blacksheep/__init__.py index 32496eb2b..6b832ffbb 100644 --- a/tortoise/contrib/blacksheep/__init__.py +++ b/tortoise/contrib/blacksheep/__init__.py @@ -94,7 +94,7 @@ async def init_orm(context) -> None: # pylint: disable=W0612 @app.on_stop async def close_orm(context) -> None: # pylint: disable=W0612 - await Tortoise.close_connections() + await connections.close_all() logger.info("Tortoise-ORM shutdown") if add_exception_handlers: diff --git a/tortoise/contrib/fastapi/__init__.py b/tortoise/contrib/fastapi/__init__.py index 83451f0a2..8912cac1f 100644 --- a/tortoise/contrib/fastapi/__init__.py +++ b/tortoise/contrib/fastapi/__init__.py @@ -98,7 +98,7 @@ async def init_orm() -> None: # pylint: disable=W0612 @app.on_event("shutdown") async def close_orm() -> None: # pylint: disable=W0612 - await Tortoise.close_connections() + await connections.close_all() logger.info("Tortoise-ORM shutdown") if add_exception_handlers: diff --git a/tortoise/contrib/quart/__init__.py b/tortoise/contrib/quart/__init__.py index 14ea890a1..184d3b0e2 100644 --- a/tortoise/contrib/quart/__init__.py +++ b/tortoise/contrib/quart/__init__.py @@ -91,7 +91,7 @@ async def init_orm() -> None: # pylint: disable=W0612 @app.after_serving async def close_orm() -> None: # pylint: disable=W0612 - await Tortoise.close_connections() + await connections.close_all() logger.info("Tortoise-ORM shutdown") @app.cli.command() # type: ignore @@ -103,7 +103,7 @@ async def inner() -> None: config=config, config_file=config_file, db_url=db_url, modules=modules ) await Tortoise.generate_schemas() - await Tortoise.close_connections() + await connections.close_all() logger.setLevel(logging.DEBUG) loop = asyncio.get_event_loop() diff --git a/tortoise/contrib/sanic/__init__.py b/tortoise/contrib/sanic/__init__.py index 5cb2f7c15..7430dfd52 100644 --- a/tortoise/contrib/sanic/__init__.py +++ b/tortoise/contrib/sanic/__init__.py @@ -87,5 +87,5 @@ async def init_orm(app, loop): # pylint: disable=W0612 @app.listener("after_server_stop") # type:ignore async def close_orm(app, loop): # pylint: disable=W0612 - await Tortoise.close_connections() + await connections.close_all() logger.info("Tortoise-ORM shutdown") diff --git a/tortoise/contrib/starlette/__init__.py b/tortoise/contrib/starlette/__init__.py index 38c209cdf..a42c63f3f 100644 --- a/tortoise/contrib/starlette/__init__.py +++ b/tortoise/contrib/starlette/__init__.py @@ -87,5 +87,5 @@ async def init_orm() -> None: # pylint: disable=W0612 @app.on_event("shutdown") async def close_orm() -> None: # pylint: disable=W0612 - await Tortoise.close_connections() + await connections.close_all() logger.info("Tortoise-ORM shutdown") From fbfaff81fc5533e56e8a973d5b50f8ed6a8c39bd Mon Sep 17 00:00:00 2001 From: Aditya N Date: Wed, 19 Jan 2022 14:39:51 +0530 Subject: [PATCH 21/25] - Updated docs --- docs/connections.rst | 4 ++-- tortoise/connection.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/connections.rst b/docs/connections.rst index 4e9448fdc..344c89fec 100644 --- a/docs/connections.rst +++ b/docs/connections.rst @@ -44,8 +44,8 @@ Below is a simple code snippet which shows how the interface can be accessed: Any attempt to modify or override its behaviour at runtime is risky and not recommended. -Please refer to :ref:`this example` for a detailed demonstration of how developers can use -this in practice. +Please refer to :ref:`this example` for a detailed demonstration of how this API can be used +in practice. API Reference diff --git a/tortoise/connection.py b/tortoise/connection.py index 696ec9118..ccc56c0ff 100644 --- a/tortoise/connection.py +++ b/tortoise/connection.py @@ -180,7 +180,7 @@ async def close_all(self, discard: bool = True) -> None: All closed connections will be removed from the storage by default. :param discard: - If ``False``, the connection object is closed but `retained` in the storage. + If ``False``, all connection objects are closed but `retained` in the storage. """ tasks = [conn.close() for conn in self.all()] await asyncio.gather(*tasks) From 58d88433292db4eb1c702d11ca9e97439a63503d Mon Sep 17 00:00:00 2001 From: Aditya N Date: Sat, 5 Mar 2022 23:34:12 +0530 Subject: [PATCH 22/25] - Removed current_transaction_map occurrences in code --- tortoise/__init__.py | 1 - tortoise/transactions.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/tortoise/__init__.py b/tortoise/__init__.py index a3f2cc842..e46917488 100644 --- a/tortoise/__init__.py +++ b/tortoise/__init__.py @@ -24,7 +24,6 @@ from tortoise.filters import get_m2m_filters from tortoise.log import logger from tortoise.models import Model, ModelMeta -from tortoise.transactions import current_transaction_map from tortoise.utils import generate_schema_for_client diff --git a/tortoise/transactions.py b/tortoise/transactions.py index dd86c28d1..f8a01eed3 100644 --- a/tortoise/transactions.py +++ b/tortoise/transactions.py @@ -4,7 +4,6 @@ from tortoise import connections from tortoise.exceptions import ParamsError -current_transaction_map: dict = {} if TYPE_CHECKING: # pragma: nocoverage from tortoise.backends.base.client import BaseDBAsyncClient, TransactionContext @@ -14,7 +13,6 @@ def _get_connection(connection_name: Optional[str]) -> "BaseDBAsyncClient": - if connection_name: connection = connections.get(connection_name) elif len(connections.db_config) == 1: From ee0a1893a768b6079bb0e77a968755d7b62bb86d Mon Sep 17 00:00:00 2001 From: Aditya N Date: Thu, 17 Mar 2022 00:12:07 +0530 Subject: [PATCH 23/25] - Fixed breaking tests due to merge --- tests/backends/test_connection_params.py | 24 +++++++--- tests/schema/test_generate_schema.py | 4 +- tortoise/backends/asyncpg/client.py | 7 +-- tortoise/backends/psycopg/client.py | 61 ++---------------------- 4 files changed, 25 insertions(+), 71 deletions(-) diff --git a/tests/backends/test_connection_params.py b/tests/backends/test_connection_params.py index 1ae4b35f8..594ca48c2 100644 --- a/tests/backends/test_connection_params.py +++ b/tests/backends/test_connection_params.py @@ -14,7 +14,10 @@ async def asyncTearDown(self) -> None: await super().asyncTearDown() async def test_mysql_connection_params(self): - with patch("asyncmy.create_pool", new=AsyncMock()) as mysql_connect: + with patch( + "tortoise.backends.mysql.client.mysql.create_pool", + new=AsyncMock() + ) as mysql_connect: await connections._init( { "models": { @@ -50,7 +53,10 @@ async def test_mysql_connection_params(self): async def test_asyncpg_connection_params(self): try: - with patch("asyncpg.create_pool", new=AsyncMock()) as asyncpg_connect: + with patch( + "tortoise.backends.asyncpg.client.asyncpg.create_pool", + new=AsyncMock() + ) as asyncpg_connect: await connections._init( { "models": { @@ -90,8 +96,13 @@ async def test_asyncpg_connection_params(self): async def test_psycopg_connection_params(self): try: - with patch("psycopg_pool.AsyncConnectionPool.open", new=AsyncMock()) as psycopg_connect: - await Tortoise._init_connections( + with patch( + "tortoise.backends.psycopg.client.PsycopgClient.create_pool", + new=AsyncMock() + ) as patched_create_pool: + mocked_pool = AsyncMock() + patched_create_pool.return_value = mocked_pool + await connections._init( { "models": { "engine": "tortoise.backends.psycopg", @@ -108,8 +119,9 @@ async def test_psycopg_connection_params(self): }, False, ) - - psycopg_connect.assert_awaited_once_with( # nosec + await connections.get("models").create_connection(with_db=True) + patched_create_pool.assert_awaited_once() + mocked_pool.open.assert_awaited_once_with( # nosec wait=True, timeout=1, ) diff --git a/tests/schema/test_generate_schema.py b/tests/schema/test_generate_schema.py index b3a819440..34f6c65a3 100644 --- a/tests/schema/test_generate_schema.py +++ b/tests/schema/test_generate_schema.py @@ -1168,7 +1168,7 @@ async def init_for(self, module: str, safe=False) -> None: "apps": {"models": {"models": [module], "default_connection": "default"}}, } ) - self.sqls = get_schema_sql(Tortoise._connections["default"], safe).split("; ") + self.sqls = get_schema_sql(connections.get("default"), safe).split("; ") except ImportError: raise test.SkipTest("asyncpg not installed") @@ -1194,6 +1194,6 @@ async def init_for(self, module: str, safe=False) -> None: "apps": {"models": {"models": [module], "default_connection": "default"}}, } ) - self.sqls = get_schema_sql(Tortoise._connections["default"], safe).split("; ") + self.sqls = get_schema_sql(connections.get("default"), safe).split("; ") except ImportError: raise test.SkipTest("psycopg not installed") diff --git a/tortoise/backends/asyncpg/client.py b/tortoise/backends/asyncpg/client.py index c647f8776..b60cfe35b 100644 --- a/tortoise/backends/asyncpg/client.py +++ b/tortoise/backends/asyncpg/client.py @@ -1,5 +1,5 @@ import asyncio -from typing import Any, Callable, List, Optional, Tuple, TypeVar +from typing import Any, Callable, List, Optional, Tuple, TypeVar, Union import asyncpg from asyncpg.transaction import Transaction @@ -11,7 +11,7 @@ ConnectionWrapper, NestedTransactionPooledContext, TransactionContext, - TransactionContextPooled, + TransactionContextPooled, PoolConnectionWrapper, ) from tortoise.backends.base_postgres.client import BasePostgresClient, translate_exceptions from tortoise.exceptions import ( @@ -147,9 +147,6 @@ async def execute_query_dict(self, query: str, values: Optional[list] = None) -> return list(map(dict, await connection.fetch(query, *values))) return list(map(dict, await connection.fetch(query))) - def _in_transaction(self) -> "TransactionContext": - return TransactionContextPooled(TransactionWrapper(self)) - class TransactionWrapper(AsyncpgDBClient, BaseTransactionWrapper): def __init__(self, connection: AsyncpgDBClient) -> None: diff --git a/tortoise/backends/psycopg/client.py b/tortoise/backends/psycopg/client.py index ff358d477..97a53a047 100644 --- a/tortoise/backends/psycopg/client.py +++ b/tortoise/backends/psycopg/client.py @@ -12,6 +12,7 @@ import tortoise.backends.base_postgres.client as postgres_client import tortoise.backends.psycopg.executor as executor import tortoise.exceptions as exceptions +from tortoise.backends.base.client import PoolConnectionWrapper from tortoise.backends.psycopg.schema_generator import PsycopgSchemaGenerator FuncType = typing.Callable[..., typing.Any] @@ -27,61 +28,9 @@ async def release(self, connection: psycopg.AsyncConnection): await self.putconn(connection) -class PoolConnectionWrapper(base_client.PoolConnectionWrapper): - pool: AsyncConnectionPool - # Using _connection instead of connection because of mypy - _connection: typing.Optional[psycopg.AsyncConnection] = None - - def __init__(self, pool: AsyncConnectionPool) -> None: - self.pool = pool - - async def open(self, timeout: float = 30.0): - try: - await self.pool.open(wait=True, timeout=timeout) - except psycopg_pool.PoolTimeout as exception: - raise exceptions.DBConnectionError from exception - - async def close(self, timeout: float = 0.0): - await self.pool.close(timeout=timeout) - - async def __aenter__(self) -> psycopg.AsyncConnection: - if self.pool is None: - raise RuntimeError("Connection pool is not initialized") - - if self._connection: - raise RuntimeError("Connection is already acquired") - else: - self._connection = await self.pool.getconn() - - return await self._connection.__aenter__() - - async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: - if self._connection: - await self._connection.__aexit__(exc_type, exc_val, exc_tb) - await self.pool.putconn(self._connection) - - self._connection = None - - # TortoiseORM has this interface hardcoded in the tests, so we need to support it - async def acquire(self) -> psycopg.AsyncConnection: - if not self._connection: - await self.__aenter__() - - if not self._connection: - raise RuntimeError("Connection is not acquired") - return self._connection - - # TortoiseORM has this interface hardcoded in the tests, so we need to support it - async def release(self, connection: psycopg.AsyncConnection) -> None: - if self._connection is not connection: - raise RuntimeError("Wrong connection is being released") - await self.__aexit__(None, None, None) - - class PsycopgClient(postgres_client.BasePostgresClient): executor_class: typing.Type[executor.PsycopgExecutor] = executor.PsycopgExecutor schema_generator: typing.Type[PsycopgSchemaGenerator] = PsycopgSchemaGenerator - # _pool: typing.Optional[PoolConnectionWrapper] = None _pool: typing.Optional[AsyncConnectionPool] = None _connection: psycopg.AsyncConnection default_timeout: float = 30 @@ -230,11 +179,7 @@ async def _translate_exceptions(self, func, *args, **kwargs) -> Exception: def acquire_connection( self, ) -> typing.Union[base_client.ConnectionWrapper, PoolConnectionWrapper]: - if not self._pool: - raise exceptions.OperationalError("Connection pool not initialized") - - pool_wrapper: PoolConnectionWrapper = PoolConnectionWrapper(self._pool) - return pool_wrapper + return PoolConnectionWrapper(self) def _in_transaction(self) -> base_client.TransactionContext: return base_client.TransactionContextPooled(TransactionWrapper(self)) @@ -256,7 +201,7 @@ def _in_transaction(self) -> base_client.TransactionContext: return base_client.NestedTransactionPooledContext(self) def acquire_connection(self) -> base_client.ConnectionWrapper: - return base_client.ConnectionWrapper(self._connection, self._lock) + return base_client.ConnectionWrapper(self._lock, self) @postgres_client.translate_exceptions async def start(self) -> None: From 8711b4fe737a7d8c3435ab60f5f5120119ff7d0c Mon Sep 17 00:00:00 2001 From: Aditya N Date: Thu, 17 Mar 2022 00:18:46 +0530 Subject: [PATCH 24/25] - Fixed lint errors --- tests/backends/test_connection_params.py | 9 +++------ tortoise/backends/asyncpg/client.py | 3 ++- tortoise/transactions.py | 1 - 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/tests/backends/test_connection_params.py b/tests/backends/test_connection_params.py index 594ca48c2..ef346760a 100644 --- a/tests/backends/test_connection_params.py +++ b/tests/backends/test_connection_params.py @@ -15,8 +15,7 @@ async def asyncTearDown(self) -> None: async def test_mysql_connection_params(self): with patch( - "tortoise.backends.mysql.client.mysql.create_pool", - new=AsyncMock() + "tortoise.backends.mysql.client.mysql.create_pool", new=AsyncMock() ) as mysql_connect: await connections._init( { @@ -54,8 +53,7 @@ async def test_mysql_connection_params(self): async def test_asyncpg_connection_params(self): try: with patch( - "tortoise.backends.asyncpg.client.asyncpg.create_pool", - new=AsyncMock() + "tortoise.backends.asyncpg.client.asyncpg.create_pool", new=AsyncMock() ) as asyncpg_connect: await connections._init( { @@ -97,8 +95,7 @@ async def test_asyncpg_connection_params(self): async def test_psycopg_connection_params(self): try: with patch( - "tortoise.backends.psycopg.client.PsycopgClient.create_pool", - new=AsyncMock() + "tortoise.backends.psycopg.client.PsycopgClient.create_pool", new=AsyncMock() ) as patched_create_pool: mocked_pool = AsyncMock() patched_create_pool.return_value = mocked_pool diff --git a/tortoise/backends/asyncpg/client.py b/tortoise/backends/asyncpg/client.py index b60cfe35b..69412920b 100644 --- a/tortoise/backends/asyncpg/client.py +++ b/tortoise/backends/asyncpg/client.py @@ -10,8 +10,9 @@ BaseTransactionWrapper, ConnectionWrapper, NestedTransactionPooledContext, + PoolConnectionWrapper, TransactionContext, - TransactionContextPooled, PoolConnectionWrapper, + TransactionContextPooled, ) from tortoise.backends.base_postgres.client import BasePostgresClient, translate_exceptions from tortoise.exceptions import ( diff --git a/tortoise/transactions.py b/tortoise/transactions.py index f8a01eed3..8a1d92b88 100644 --- a/tortoise/transactions.py +++ b/tortoise/transactions.py @@ -4,7 +4,6 @@ from tortoise import connections from tortoise.exceptions import ParamsError - if TYPE_CHECKING: # pragma: nocoverage from tortoise.backends.base.client import BaseDBAsyncClient, TransactionContext From 7518a0225aa5746f7da142cabd8ba85ce3f45b2a Mon Sep 17 00:00:00 2001 From: Aditya N Date: Fri, 18 Mar 2022 11:49:02 +0530 Subject: [PATCH 25/25] - Updated CHANGELOG.rst --- CHANGELOG.rst | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 6eab884ca..32b224deb 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -13,9 +13,21 @@ Changelog Added ^^^^^ - Added psycopg backend support. +- Added a new unified and robust connection management interface to access DB connections which includes support for + lazy connection creation and much more. For more details, + check out this `PR `_ Fixed ^^^^^ - Fix `bulk_create` doesn't work correctly with more than 1 update_fields. (#1046) +Deprecated +^^^^^ +- Existing connection management interface and related public APIs which are deprecated: + - `Tortoise.get_connection` + - `Tortoise.close_connections` +Changed +^^^^^ +- Refactored `tortoise.transactions.get_connection` method to `tortoise.transactions._get_connection`. + Note that this method has now been marked **private to this module and is not part of the public API** 0.18.1 ------