Skip to content

Commit

Permalink
fix(optimizer): Avoid merging prefetches when using aliases (#698)
Browse files Browse the repository at this point in the history
  • Loading branch information
bellini666 authored Jan 26, 2025
1 parent 610e12b commit d95acd6
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 40 deletions.
57 changes: 46 additions & 11 deletions strawberry_django/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import copy
import dataclasses
import itertools
from collections import defaultdict
from collections import Counter, defaultdict
from collections.abc import Callable
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -58,6 +58,7 @@
from .utils.inspect import (
PrefetchInspector,
get_model_field,
get_model_fields,
get_possible_type_definitions,
)
from .utils.typing import (
Expand Down Expand Up @@ -1035,19 +1036,29 @@ def _get_model_hints(
if pk is not None:
store.only.append(pk.attname)

for f_selections in _get_selections(info, parent_type).values():
field_data = _get_field_data(
f_selections,
object_definition,
schema,
parent_type=parent_type,
info=info,
selections = [
field_data
for f_selection in _get_selections(info, parent_type).values()
if (
field_data := _get_field_data(
f_selection,
object_definition,
schema,
parent_type=parent_type,
info=info,
)
)
if field_data is None:
is not None
]
fields_counter = Counter(field_data[0] for field_data in selections)

for field, f_definition, f_selection, f_info in selections:
# If a field is selected more than once in the query, that means it is being
# aliased. In this case, optimizing it would make one query to affect the other,
# resulting in wrong results for both.
if fields_counter[field] > 1:
continue

field, f_definition, f_selection, f_info = field_data

# Add annotations from the field if they exist
if field_store := _get_hints_from_field(field, f_info=f_info, prefix=prefix):
store |= field_store
Expand Down Expand Up @@ -1089,6 +1100,30 @@ def _get_model_hints(
store.only.extend(inner_store.only)
store.select_related.extend(inner_store.select_related)

# In case we skipped optimization for a relation, we might end up with a new QuerySet
# which would not select its parent relation field on `.only()`, causing n+1 issues.
# Make sure that in this case we also select it.
if level == 0 and store.only and info.path.prev:
own_fk_fields = [
field
for field in get_model_fields(model).values()
if isinstance(field, models.ForeignKey)
]

path = info.path
while path := path.prev:
type_ = schema.get_type_by_name(path.typename)
if not isinstance(type_, StrawberryObjectDefinition):
continue

if not (strawberry_django_type := get_django_definition(type_.origin)):
continue

for field in own_fk_fields:
if field.related_model is strawberry_django_type.model:
store.only.append(field.attname)
break

return store


Expand Down
76 changes: 47 additions & 29 deletions tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,14 +308,6 @@ def test_query_forward_with_fragments(db, gql_client: GraphQLTestClient):
}
... milestoneFrag
}
milestoneAgain: milestone {
name
project {
id
name
}
... milestoneFrag
}
}
}
}
Expand All @@ -341,7 +333,6 @@ def test_query_forward_with_fragments(db, gql_client: GraphQLTestClient):
"nameWithKind": f"{i.kind}: {i.name}",
"nameWithPriority": f"{i.kind}: {i.priority}",
"milestone": m_res,
"milestoneAgain": m_res,
},
)

Expand Down Expand Up @@ -538,12 +529,6 @@ def test_query_prefetch_with_fragments(db, gql_client: GraphQLTestClient):
... milestoneFrag
}
}
otherIssues: issues {
id
milestone {
... milestoneFrag
}
}
}
}
}
Expand All @@ -566,7 +551,6 @@ def test_query_prefetch_with_fragments(db, gql_client: GraphQLTestClient):
"name": p_res["name"],
},
"issues": [],
"otherIssues": [],
}
p_res["milestones"].append(m_res)
for i in IssueFactory.create_batch(3, milestone=m):
Expand All @@ -585,22 +569,10 @@ def test_query_prefetch_with_fragments(db, gql_client: GraphQLTestClient):
},
},
)
m_res["otherIssues"].append(
{
"id": to_base64("IssueType", i.id),
"milestone": {
"id": m_res["id"],
"project": {
"id": p_res["id"],
"name": p_res["name"],
},
},
},
)

assert len(expected) == 3
for e in expected:
with assert_num_queries(3 if DjangoOptimizerExtension.enabled.get() else 8):
with assert_num_queries(3 if DjangoOptimizerExtension.enabled.get() else 5):
res = gql_client.query(query, {"node_id": e["id"]})

assert res.data == {"project": e}
Expand Down Expand Up @@ -1089,6 +1061,52 @@ def test_query_nested_connection_with_filter(db, gql_client: GraphQLTestClient):
} == expected


@pytest.mark.django_db(transaction=True)
def test_query_nested_connection_with_filter_and_alias(
db, gql_client: GraphQLTestClient
):
query = """
query TestQuery ($id: GlobalID!) {
milestone(id: $id) {
id
fooIssues: issuesWithFilters (filters: {search: "Foo"}) {
edges {
node {
id
}
}
}
barIssues: issuesWithFilters (filters: {search: "Bar"}) {
edges {
node {
id
}
}
}
}
}
"""

milestone = MilestoneFactory.create()
issue1 = IssueFactory.create(milestone=milestone, name="Foo")
issue2 = IssueFactory.create(milestone=milestone, name="Foo Bar")
issue3 = IssueFactory.create(milestone=milestone, name="Bar Foo")
issue4 = IssueFactory.create(milestone=milestone, name="Bar Bin")

with assert_num_queries(3):
res = gql_client.query(query, {"id": to_base64("MilestoneType", milestone.pk)})

assert isinstance(res.data, dict)
result = res.data["milestone"]
assert isinstance(result, dict)

foo_expected = {to_base64("IssueType", i.pk) for i in [issue1, issue2, issue3]}
assert {edge["node"]["id"] for edge in result["fooIssues"]["edges"]} == foo_expected

bar_expected = {to_base64("IssueType", i.pk) for i in [issue2, issue3, issue4]}
assert {edge["node"]["id"] for edge in result["barIssues"]["edges"]} == bar_expected


@pytest.mark.django_db(transaction=True)
def test_query_with_optimizer_paginated_prefetch():
@strawberry_django.type(Milestone, pagination=True)
Expand Down

0 comments on commit d95acd6

Please sign in to comment.