diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 48d7e0b97..d0cc0925c 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -6,6 +6,7 @@ Changelog ------- * Fixed invalid ``var IN ()`` SQL generated using ``__in=`` and ``__not_in`` filters. * Fix bug with order_by on nested fields +* Fix joining with self by reverse-foreign-key for filtering and annotation 0.15.20 ------ diff --git a/tests/test_relations.py b/tests/test_relations.py index 9091f6ae7..3f1da043e 100644 --- a/tests/test_relations.py +++ b/tests/test_relations.py @@ -192,6 +192,33 @@ async def test_self_ref(self): self.assertEqual(await root2.full_hierarchy__async_for(), ROOT_TEXT) self.assertEqual(await root2.full_hierarchy__fetch_related(), ROOT_TEXT) + async def test_self_ref_filter_by_child(self): + self.maxDiff = None + root = await Employee.create(name="Root") + await Employee.create(name="1. First H1", manager=root) + await Employee.create(name="2. Second H1", manager=root) + + root2 = await Employee.get(team_members__name="1. First H1") + self.assertEqual(root.id, root2.id) + + async def test_self_ref_filter_both(self): + self.maxDiff = None + root = await Employee.create(name="Root") + await Employee.create(name="1. First H1", manager=root) + await Employee.create(name="2. Second H1", manager=root) + + root2 = await Employee.get(name="Root", team_members__name="1. First H1") + self.assertEqual(root.id, root2.id) + + async def test_self_ref_annotate(self): + self.maxDiff = None + root = await Employee.create(name="Root") + await Employee.create(name="1. First H1", manager=root) + await Employee.create(name="2. Second H1", manager=root) + + root_ann = await Employee.get(name="Root").annotate(num_team_members=Count("team_members")) + self.assertEqual(root_ann.num_team_members, 2) + async def test_prefetch_related_fk(self): tournament = await Tournament.create(name="New Tournament") await Event.create(name="Test", tournament_id=tournament.id) diff --git a/tortoise/query_utils.py b/tortoise/query_utils.py index 80a71d347..9c25718c0 100644 --- a/tortoise/query_utils.py +++ b/tortoise/query_utils.py @@ -48,9 +48,7 @@ def _get_joins_for_related_field( ) -> List[Tuple[Table, Criterion]]: required_joins = [] - related_table = ( - related_field.model_class._meta.basetable - ) # .as_(f"{table.get_table_name()}__{related_field_name}") + related_table: Table = related_field.model_class._meta.basetable if isinstance(related_field, ManyToManyFieldInstance): through_table = Table(related_field.through) required_joins.append( @@ -68,6 +66,8 @@ def _get_joins_for_related_field( ) ) elif isinstance(related_field, BackwardFKRelation): + if table == related_table: + related_table = related_table.as_(f"{table.get_table_name()}__{related_field_name}") required_joins.append( ( related_table,