Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

A unified, robust and bug-free connection management interface for the ORM #1001

Merged
merged 31 commits into from
Mar 20, 2022
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
c2a398b
- Initial working commit for refactored connection management.
blazing-gig Dec 5, 2021
c5c7d31
- Refactored test_init.py to conform to the new connection management…
blazing-gig Dec 5, 2021
1fd9dac
- Fixed linting errors (flake8 and mypy)
blazing-gig Dec 5, 2021
faa4d87
- Modified comment
blazing-gig Dec 5, 2021
79b187a
- Code cleanup
blazing-gig Dec 6, 2021
8184432
- Fixed style issues
blazing-gig Dec 7, 2021
d3e7403
- Fixed mysql breaking tests
blazing-gig Dec 19, 2021
212c8f7
Merge remote-tracking branch 'origin/develop' into feat/connection_mgmt
aditya-n-invcr Dec 20, 2021
adacec3
Merge remote-tracking branch 'origin/develop' into feat/connection_mgmt
aditya-n-invcr Jan 11, 2022
18912ef
- Merged latest head from develop and fixed breaking tests
blazing-gig Jan 13, 2022
9ed0e71
- Fixed style and lint issues
blazing-gig Jan 13, 2022
69d947e
- Fixed code quality issues (codacy)
blazing-gig Jan 13, 2022
3b3b2e5
- Refactored core and test files to use the new connection interface
blazing-gig Jan 13, 2022
a7f47ba
- Fixed style and lint errors
blazing-gig Jan 13, 2022
e8e529b
- Added unit tests for the new connection interface
blazing-gig Jan 18, 2022
ab7051e
- Added docs for the new connection interface
blazing-gig Jan 18, 2022
d99aeda
- Fixed codacy check
blazing-gig Jan 18, 2022
9adabf7
- Fixed codacy check
blazing-gig Jan 18, 2022
1edd813
- Fixed codacy check
blazing-gig Jan 18, 2022
d7e7048
- Fixed codacy check
blazing-gig Jan 18, 2022
c04d1eb
- Fixed codacy check
blazing-gig Jan 18, 2022
32f87e4
- Refactored Tortoise.close_connections to use the connections.close_…
blazing-gig Jan 19, 2022
fbfaff8
- Updated docs
blazing-gig Jan 19, 2022
0468d81
Merge remote-tracking branch 'origin/develop' into feat/connection_mgmt
aditya-n-invcr Feb 16, 2022
522945e
Merge remote-tracking branch 'origin/develop' into feat/connection_mgmt
aditya-n-invcr Feb 25, 2022
58d8843
- Removed current_transaction_map occurrences in code
blazing-gig Mar 5, 2022
092448b
Merge remote-tracking branch 'origin/develop' into feat/connection_mgmt
aditya-n-invcr Mar 5, 2022
ee0a189
- Fixed breaking tests due to merge
blazing-gig Mar 16, 2022
8711b4f
- Fixed lint errors
blazing-gig Mar 16, 2022
7518a02
- Updated CHANGELOG.rst
blazing-gig Mar 18, 2022
eb6fb02
Merge remote-tracking branch 'origin/develop' into feat/connection_mgmt
aditya-n-invcr Mar 18, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions tests/model_setup/test__models__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import re
from unittest.mock import AsyncMock, patch

from tortoise import Tortoise
from tortoise import Tortoise, connections
from tortoise.contrib import test
from tortoise.exceptions import ConfigurationError
from tortoise.utils import get_schema_sql
Expand All @@ -15,7 +15,6 @@ class TestGenerateSchema(test.SimpleTestCase):
async def setUp(self):
try:
Tortoise.apps = {}
Tortoise._connections = {}
Tortoise._inited = False
except ConfigurationError:
pass
Expand All @@ -27,7 +26,6 @@ async def setUp(self):
]

async def tearDown(self):
Tortoise._connections = {}
await Tortoise._reset_apps()

async def init_for(self, module: str, safe=False) -> None:
Expand All @@ -47,7 +45,7 @@ async def init_for(self, module: str, safe=False) -> None:
"apps": {"models": {"models": [module], "default_connection": "default"}},
}
)
self.sqls = get_schema_sql(Tortoise._connections["default"], safe).split(";\n")
self.sqls = get_schema_sql(connections.get("default"), safe).split(";\n")

def get_sql(self, text: str) -> str:
return str(re.sub(r"[ \t\n\r]+", " ", [sql for sql in self.sqls if text in sql][0]))
Expand Down
16 changes: 7 additions & 9 deletions tests/model_setup/test_init.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os

from tortoise import Tortoise
from tortoise import Tortoise, connections
from tortoise.contrib import test
from tortoise.exceptions import ConfigurationError

Expand All @@ -9,14 +9,12 @@ class TestInitErrors(test.SimpleTestCase):
async def setUp(self):
try:
Tortoise.apps = {}
Tortoise._connections = {}
Tortoise._inited = False
except ConfigurationError:
pass
Tortoise._inited = False

async def tearDown(self):
await Tortoise.close_connections()
await Tortoise._reset_apps()

async def test_basic_init(self):
Expand All @@ -34,7 +32,7 @@ async def test_basic_init(self):
}
)
self.assertIn("models", Tortoise.apps)
self.assertIsNotNone(Tortoise.get_connection("default"))
self.assertIsNotNone(connections.get("default"))

async def test_empty_modules_init(self):
with self.assertWarnsRegex(RuntimeWarning, 'Module "tests.model_setup" has no models'):
Expand Down Expand Up @@ -216,7 +214,7 @@ async def test_default_connection_init(self):
}
)
self.assertIn("models", Tortoise.apps)
self.assertIsNotNone(Tortoise.get_connection("default"))
self.assertIsNotNone(connections.get("default"))

async def test_db_url_init(self):
await Tortoise.init(
Expand All @@ -228,14 +226,14 @@ async def test_db_url_init(self):
}
)
self.assertIn("models", Tortoise.apps)
self.assertIsNotNone(Tortoise.get_connection("default"))
self.assertIsNotNone(connections.get("default"))

async def test_shorthand_init(self):
await Tortoise.init(
db_url=f"sqlite://{':memory:'}", modules={"models": ["tests.testmodels"]}
)
self.assertIn("models", Tortoise.apps)
self.assertIsNotNone(Tortoise.get_connection("default"))
self.assertIsNotNone(connections.get("default"))

async def test_init_wrong_connection_engine(self):
with self.assertRaisesRegex(ImportError, "tortoise.backends.test"):
Expand Down Expand Up @@ -324,13 +322,13 @@ async def test_init_config_file_wrong_extension(self):
async def test_init_json_file(self):
await Tortoise.init(config_file=os.path.dirname(__file__) + "/init.json")
self.assertIn("models", Tortoise.apps)
self.assertIsNotNone(Tortoise.get_connection("default"))
self.assertIsNotNone(connections.get("default"))

@test.skipIf(os.name == "nt", "path issue on Windows")
async def test_init_yaml_file(self):
await Tortoise.init(config_file=os.path.dirname(__file__) + "/init.yaml")
self.assertIn("models", Tortoise.apps)
self.assertIsNotNone(Tortoise.get_connection("default"))
self.assertIsNotNone(connections.get("default"))

async def test_generate_schema_without_init(self):
with self.assertRaisesRegex(
Expand Down
12 changes: 5 additions & 7 deletions tests/schema/test_generate_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import re
from unittest.mock import AsyncMock, patch

from tortoise import Tortoise
from tortoise import Tortoise, connections
from tortoise.contrib import test
from tortoise.exceptions import ConfigurationError
from tortoise.utils import get_schema_sql
Expand All @@ -12,7 +12,6 @@ class TestGenerateSchema(test.SimpleTestCase):
async def setUp(self):
try:
Tortoise.apps = {}
Tortoise._connections = {}
Tortoise._inited = False
except ConfigurationError:
pass
Expand All @@ -24,7 +23,6 @@ async def setUp(self):
]

async def tearDown(self):
Tortoise._connections = {}
await Tortoise._reset_apps()

async def init_for(self, module: str, safe=False) -> None:
Expand All @@ -42,7 +40,7 @@ async def init_for(self, module: str, safe=False) -> None:
"apps": {"models": {"models": [module], "default_connection": "default"}},
}
)
self.sqls = get_schema_sql(Tortoise._connections["default"], safe).split(";\n")
self.sqls = get_schema_sql(connections.get("default"), safe).split(";\n")

def get_sql(self, text: str) -> str:
return re.sub(r"[ \t\n\r]+", " ", " ".join([sql for sql in self.sqls if text in sql]))
Expand Down Expand Up @@ -415,7 +413,7 @@ async def init_for(self, module: str, safe=False) -> None:
"apps": {"models": {"models": [module], "default_connection": "default"}},
}
)
self.sqls = get_schema_sql(Tortoise._connections["default"], safe).split("; ")
self.sqls = get_schema_sql(connections.get("default"), safe).split("; ")
except ImportError:
raise test.SkipTest("aiomysql not installed")

Expand Down Expand Up @@ -498,7 +496,7 @@ async def test_schema_no_db_constraint(self):
async def test_schema(self):
self.maxDiff = None
await self.init_for("tests.schema.models_schema_create")
sql = get_schema_sql(Tortoise.get_connection("default"), safe=False)
sql = get_schema_sql(connections.get("default"), safe=False)
self.assertEqual(
sql.strip(),
"""
Expand Down Expand Up @@ -793,7 +791,7 @@ async def init_for(self, module: str, safe=False) -> None:
"apps": {"models": {"models": [module], "default_connection": "default"}},
}
)
self.sqls = get_schema_sql(Tortoise._connections["default"], safe).split("; ")
self.sqls = get_schema_sql(connections.get("default"), safe).split("; ")
except ImportError:
raise test.SkipTest("asyncpg not installed")

Expand Down
6 changes: 3 additions & 3 deletions tests/test_two_databases.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from tests.testmodels import Event, EventTwo, TeamTwo, Tournament
from tortoise import Tortoise
from tortoise import Tortoise, connections
from tortoise.contrib import test
from tortoise.exceptions import OperationalError, ParamsError
from tortoise.transactions import in_transaction
Expand All @@ -17,8 +17,8 @@ async def setUp(self):
}
await Tortoise.init(merged_config, _create_db=True)
await Tortoise.generate_schemas()
self.db = Tortoise.get_connection("models")
self.second_db = Tortoise.get_connection("events")
self.db = connections.get("models")
self.second_db = connections.get("events")

async def tearDown(self):
await Tortoise._drop_databases()
Expand Down
14 changes: 7 additions & 7 deletions tests/utils/test_run_async.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
from unittest import TestCase, skipIf

from tortoise import Tortoise, run_async
from tortoise import Tortoise, connections, run_async


@skipIf(os.name == "nt", "stuck with Windows")
Expand All @@ -12,25 +12,25 @@ def setUp(self):
async def init(self):
await Tortoise.init(db_url="sqlite://:memory:", modules={"models": []})
self.somevalue = 2
self.assertNotEqual(Tortoise._connections, {})
self.assertNotEqual(connections._get_storage(), {})

async def init_raise(self):
await Tortoise.init(db_url="sqlite://:memory:", modules={"models": []})
self.somevalue = 3
self.assertNotEqual(Tortoise._connections, {})
self.assertNotEqual(connections._get_storage(), {})
raise Exception("Some exception")

def test_run_async(self):
self.assertEqual(Tortoise._connections, {})
self.assertEqual(connections._get_storage(), {})
self.assertEqual(self.somevalue, 1)
run_async(self.init())
self.assertEqual(Tortoise._connections, {})
self.assertEqual(connections._get_storage(), {})
self.assertEqual(self.somevalue, 2)

def test_run_async_raised(self):
self.assertEqual(Tortoise._connections, {})
self.assertEqual(connections._get_storage(), {})
self.assertEqual(self.somevalue, 1)
with self.assertRaises(Exception):
run_async(self.init_raise())
self.assertEqual(Tortoise._connections, {})
self.assertEqual(connections._get_storage(), {})
self.assertEqual(self.somevalue, 3)
23 changes: 13 additions & 10 deletions tortoise/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from tortoise.backends.base.client import BaseDBAsyncClient
from tortoise.backends.base.config_generator import expand_db_url, generate_config
from tortoise.connection import connections
from tortoise.exceptions import ConfigurationError
from tortoise.fields.relational import (
BackwardFKRelation,
Expand Down Expand Up @@ -40,7 +41,7 @@ def get_connection(cls, connection_name: str) -> BaseDBAsyncClient:

:raises KeyError: If connection name does not exist.
"""
return cls._connections[connection_name]
return connections.get(connection_name)

@classmethod
def describe_model(
Expand Down Expand Up @@ -542,7 +543,7 @@ async def init(
:raises ConfigurationError: For any configuration error
"""
if cls._inited:
await cls.close_connections()
await connections.close_all(discard=True)
await cls._reset_apps()
if int(bool(config) + bool(config_file) + bool(db_url)) != 1:
raise ConfigurationError(
Expand Down Expand Up @@ -595,7 +596,7 @@ async def init(
)

cls._init_timezone(use_tz, timezone)
await cls._init_connections(connections_config, _create_db)
await connections._init(connections_config, _create_db)
cls._init_apps(apps_config)
cls._init_routers(routers)

Expand Down Expand Up @@ -643,7 +644,6 @@ async def _reset_apps(cls) -> None:
if isinstance(model, ModelMeta):
model._meta.default_connection = None
cls.apps.clear()
current_transaction_map.clear()

@classmethod
async def generate_schemas(cls, safe: bool = True) -> None:
Expand All @@ -658,7 +658,7 @@ async def generate_schemas(cls, safe: bool = True) -> None:
"""
if not cls._inited:
raise ConfigurationError("You have to call .init() first before generating schemas")
for connection in cls._connections.values():
for connection in connections.all():
await generate_schema_for_client(connection, safe)

@classmethod
Expand All @@ -671,10 +671,13 @@ async def _drop_databases(cls) -> None:
"""
if not cls._inited:
raise ConfigurationError("You have to call .init() first before deleting schemas")
for connection in cls._connections.values():
await connection.close()
await connection.db_delete()
cls._connections = {}
# this closes any existing connections/pool if any and clears
# the storage
await connections.close_all()
for conn in connections.all():
await conn.db_delete()
connections.discard(conn.connection_name)

await cls._reset_apps()

@classmethod
Expand Down Expand Up @@ -706,7 +709,7 @@ async def do_stuff():
try:
loop.run_until_complete(coro)
finally:
loop.run_until_complete(Tortoise.close_connections())
loop.run_until_complete(connections.close_all(discard=True))


__version__ = "0.18.0"
6 changes: 3 additions & 3 deletions tortoise/backends/asyncpg/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ async def db_delete(self) -> None:
await self.close()

def acquire_connection(self) -> Union["PoolConnectionWrapper", "ConnectionWrapper"]:
return PoolConnectionWrapper(self._pool)
return PoolConnectionWrapper(self)

def _in_transaction(self) -> "TransactionContext":
return TransactionContextPooled(TransactionWrapper(self))
Expand Down Expand Up @@ -209,13 +209,13 @@ def __init__(self, connection: AsyncpgDBClient) -> None:
self.connection_name = connection.connection_name
self.transaction: Transaction = None
self._finalized = False
self._parent = connection
self._parent: AsyncpgDBClient = connection

def _in_transaction(self) -> "TransactionContext":
return NestedTransactionPooledContext(self)

def acquire_connection(self) -> "ConnectionWrapper":
return ConnectionWrapper(self._connection, self._lock)
return ConnectionWrapper(self._lock, self)

@translate_exceptions
async def execute_many(self, query: str, values: list) -> None:
Expand Down
Loading