Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin' into fix/event-loop-warning
Browse files Browse the repository at this point in the history
  • Loading branch information
Abdeldjalil-H committed Jan 29, 2025
2 parents 8252e46 + 960b1c1 commit c56a179
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 7 deletions.
12 changes: 12 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,18 @@ Changelog

.. rst-class:: emphasize-children

0.25
====

0.25.0 (unreleased)
------
Fixed
^^^^^

Changed
^^^^^^^
- add benchmarks for `get_for_dialect` (#1862)

0.24
====

Expand Down
28 changes: 27 additions & 1 deletion tests/benchmarks/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,14 @@

import pytest

from tests.testmodels import BenchmarkFewFields, BenchmarkManyFields
from tests.testmodels import (
BenchmarkFewFields,
BenchmarkManyFields,
Tournament,
Event,
Team,
DecimalFields,
)
from tortoise.contrib.test import _restore_default, truncate_all_models


Expand Down Expand Up @@ -87,3 +94,22 @@ def _gen():
}

return _gen


@pytest.fixture
def create_team_with_participants() -> None:
async def _create() -> None:
tournament = await Tournament.create(name="New Tournament")
event = await Event.create(name="Test", tournament_id=tournament.id)
team = await Team.create(name="Some Team")
await event.participants.add(team)

asyncio.get_event_loop().run_until_complete(_create())


@pytest.fixture
def create_decimals() -> None:
async def _create() -> None:
await DecimalFields.create(decimal=Decimal("1.23456"), decimal_nodec=18.7)

asyncio.get_event_loop().run_until_complete(_create())
27 changes: 27 additions & 0 deletions tests/benchmarks/test_expressions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import asyncio

from tests.testmodels import BenchmarkFewFields, DecimalFields
from tortoise.expressions import F
from tortoise.functions import Count


def test_expressions_count(benchmark, few_fields_benchmark_dataset):
loop = asyncio.get_event_loop()

@benchmark
def bench():
async def _bench():
await BenchmarkFewFields.annotate(text_count=Count("text"))

loop.run_until_complete(_bench())


def test_expressions_f(benchmark, create_decimals):
loop = asyncio.get_event_loop()

@benchmark
def bench():
async def _bench():
await DecimalFields.annotate(d=F("decimal")).all()

loop.run_until_complete(_bench())
32 changes: 32 additions & 0 deletions tests/benchmarks/test_field_attribute_lookup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from tortoise.fields import Field


class MyField(Field):
@property
def MY_PROPERTY(self):
return f"hi from {self.__class__.__name__}!"

OTHER_PROPERTY = "something else"

class _db_property:
def __init__(self, field: "Field"):
self.field = field

@property
def MY_PROPERTY(self):
return f"hi from {self.__class__.__name__} of {self.field.__class__.__name__}!"

class _db_cls_attribute:
MY_PROPERTY = "cls_attribute"


def test_field_attribute_lookup_get_for_dialect(benchmark):
field = MyField()

@benchmark
def bench():
field.get_for_dialect("property", "MY_PROPERTY")
field.get_for_dialect("postgres", "MY_PROPERTY")
field.get_for_dialect("cls_attribute", "MY_PROPERTY")
field.get_for_dialect("property", "OTHER_PROPERTY")
field.get_for_dialect("property", "MY_PROPERTY")
14 changes: 14 additions & 0 deletions tests/benchmarks/test_relations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import asyncio

from tests.testmodels import Event


def test_relations_values_related_m2m(benchmark, create_team_with_participants):
loop = asyncio.get_event_loop()

@benchmark
def bench():
async def _bench():
await Event.all().values("participants__name")

loop.run_until_complete(_bench())
27 changes: 21 additions & 6 deletions tortoise/fields/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,12 +326,17 @@ def get_db_field_types(self) -> Optional[dict[str, str]]:
"""
if not self.has_db_field: # pragma: nocoverage
return None
default = getattr(self, "SQL_TYPE")
return {
"": getattr(self, "SQL_TYPE"),
"": default,
**{
dialect: _db["SQL_TYPE"]
for dialect, _db in self._get_dialects().items()
if "SQL_TYPE" in _db
dialect: sql_type
for dialect, sql_type in (
(key[4:], self.get_for_dialect(key[4:], "SQL_TYPE"))
for key in dir(self)
if key.startswith("_db_")
)
if sql_type != default
},
}

Expand All @@ -342,8 +347,18 @@ def get_for_dialect(self, dialect: str, key: str) -> Any:
:param dialect: The requested SQL Dialect.
:param key: The attribute/method name.
"""
dialect_data = self._get_dialects().get(dialect, {})
return dialect_data.get(key, getattr(self, key, None))
try:
dialect_cls = getattr(self, f"_db_{dialect}") # throws AttributeError if not present
dialect_value = getattr(dialect_cls, key) # throws AttributeError if not present
except AttributeError:
pass
else: # we have dialect_cls and dialect_value, so lets use it
# it could be that dialect_value is a computed property, like in CharField._db_oracle.SQL_TYPE,
# and therefore one first needs to instantiate dialect_cls
if isinstance(dialect_value, property):
return getattr(dialect_cls(self), key)
return dialect_value
return getattr(self, key, None) # there is nothing special defined, return the value of self

def describe(self, serializable: bool) -> dict:
"""
Expand Down

0 comments on commit c56a179

Please sign in to comment.