Skip to content

Commit

Permalink
perf: AnyOfFilter shouldn't impose distinct by default unless joini…
Browse files Browse the repository at this point in the history
…ng on another table
  • Loading branch information
rodrigoalmeidaee committed Oct 17, 2024
1 parent 9ece7d2 commit 309ba84
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 1 deletion.
9 changes: 8 additions & 1 deletion drf_kit/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ def clean(self, value):
class AnyOfFilter(MultipleChoiceFilter):
field_class = _OpenChoiceField

def __init__(self, *args, **kwargs):
kwargs.setdefault("distinct", None)
super().__init__(*args, **kwargs)

def filter(self, qs, value):
if not value:
return qs
Expand All @@ -105,7 +109,10 @@ def filter(self, qs, value):
query_filter = Q(**predicate)
qs = self.get_method(qs)(query_filter)

return qs.distinct() if self.distinct else qs
possibly_filtering_on_relationship = "__" in self.field_name
force_distinct = possibly_filtering_on_relationship if self.distinct is None else self.distinct

return qs.distinct() if force_distinct else qs


class AllOfFilter(MultipleChoiceFilter):
Expand Down
1 change: 1 addition & 0 deletions test_app/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class Meta:

class WizardFilterSet(filters.BaseFilterSet):
spell_name = filters.AllOfFilter(field_name="spell_casts__spell__name")
any_spell_name = filters.AnyOfFilter(field_name="spell_casts__spell__name")

class Meta:
model = models.Wizard
Expand Down
4 changes: 4 additions & 0 deletions test_app/tests/tests_views/tests_filter_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,10 @@ def test_filter_multiple_spell_name_conjoined_return_wizards(self):
response = self.client.get(self.url, data={"spell_name": [self.spell_a.name, self.spell_b.name]})
self.assertResponseItems(expected_items=[self.wizard_a], response=response)

def test_filter_multiple_spell_name_any_of_through_relationship(self):
response = self.client.get(self.url, data={"any_spell_name": [self.spell_a.name, self.spell_b.name]})
self.assertResponseItems(expected_items=[self.wizard_a, self.wizard_b], response=response)

def test_filter_multiple_spell_name_disjointed_not_return_wizards(self):
response = self.client.get(self.url, data={"spell_name": [self.spell_a.name, self.spell_noise.name]})
self.assertResponseItems(expected_items=[], response=response)
Expand Down

0 comments on commit 309ba84

Please sign in to comment.