Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[red-knot] Handle multiple comprehension targets #13213

Merged
merged 2 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions crates/red_knot_python_semantic/src/semantic_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -666,7 +666,7 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs):
fn comprehension_scope() {
let TestCase { db, file } = test_case(
"
[x for x in iter1]
[x for x, y in iter1]
",
);

Expand All @@ -690,7 +690,22 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs):

let comprehension_symbol_table = index.symbol_table(comprehension_scope_id);

assert_eq!(names(&comprehension_symbol_table), vec!["x"]);
assert_eq!(names(&comprehension_symbol_table), vec!["x", "y"]);

let use_def = index.use_def_map(comprehension_scope_id);
for name in ["x", "y"] {
let definition = use_def
.first_public_definition(
comprehension_symbol_table
.symbol_id_by_name(name)
.expect("symbol exists"),
)
.unwrap();
assert!(matches!(
definition.node(&db),
DefinitionKind::Comprehension(_)
));
}
}

/// Test case to validate that the `x` variable used in the comprehension is referencing the
Expand Down Expand Up @@ -730,8 +745,8 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs):
let DefinitionKind::Comprehension(comprehension) = definition.node(&db) else {
panic!("expected generator definition")
};
let ast::Comprehension { target, .. } = comprehension.node();
let name = target.as_name_expr().unwrap().id().as_str();
let target = comprehension.target();
let name = target.id().as_str();

assert_eq!(name, "x");
assert_eq!(target.range(), TextRange::new(23.into(), 24.into()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ impl<'db> SemanticIndexBuilder<'db> {

// The `iter` of the first generator is evaluated in the outer scope, while all subsequent
// nodes are evaluated in the inner scope.
self.add_standalone_expression(&generator.iter);
self.visit_expr(&generator.iter);
self.push_scope(scope);

Expand All @@ -297,6 +298,7 @@ impl<'db> SemanticIndexBuilder<'db> {
}

for generator in generators_iter {
self.add_standalone_expression(&generator.iter);
self.visit_expr(&generator.iter);

self.current_assignment = Some(CurrentAssignment::Comprehension {
Expand Down Expand Up @@ -675,7 +677,11 @@ where
Some(CurrentAssignment::Comprehension { node, first }) => {
self.add_definition(
symbol,
ComprehensionDefinitionNodeRef { node, first },
ComprehensionDefinitionNodeRef {
iterable: &node.iter,
target: name_node,
first,
},
);
}
Some(CurrentAssignment::WithItem(with_item)) => {
Expand Down
37 changes: 20 additions & 17 deletions crates/red_knot_python_semantic/src/semantic_index/definition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,8 @@ pub(crate) struct ForStmtDefinitionNodeRef<'a> {

#[derive(Copy, Clone, Debug)]
pub(crate) struct ComprehensionDefinitionNodeRef<'a> {
pub(crate) node: &'a ast::Comprehension,
pub(crate) iterable: &'a ast::Expr,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was wondering if we'll also need access to the if portion of the generator here, but I think we won't; that expression should be handled separately as a Constraint, it's not part of the Definition.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that's what I thought as well which is the reason to limit this. It can easily be reverted in the future if there's a need.

pub(crate) target: &'a ast::ExprName,
pub(crate) first: bool,
}

Expand Down Expand Up @@ -211,12 +212,15 @@ impl DefinitionNodeRef<'_> {
target: AstNodeRef::new(parsed, target),
})
}
DefinitionNodeRef::Comprehension(ComprehensionDefinitionNodeRef { node, first }) => {
DefinitionKind::Comprehension(ComprehensionDefinitionKind {
node: AstNodeRef::new(parsed, node),
first,
})
}
DefinitionNodeRef::Comprehension(ComprehensionDefinitionNodeRef {
iterable,
target,
first,
}) => DefinitionKind::Comprehension(ComprehensionDefinitionKind {
iterable: AstNodeRef::new(parsed.clone(), iterable),
target: AstNodeRef::new(parsed, target),
first,
}),
DefinitionNodeRef::Parameter(parameter) => match parameter {
ast::AnyParameterRef::Variadic(parameter) => {
DefinitionKind::Parameter(AstNodeRef::new(parsed, parameter))
Expand Down Expand Up @@ -262,7 +266,7 @@ impl DefinitionNodeRef<'_> {
iterable: _,
target,
}) => target.into(),
Self::Comprehension(ComprehensionDefinitionNodeRef { node, first: _ }) => node.into(),
Self::Comprehension(ComprehensionDefinitionNodeRef { target, .. }) => target.into(),
Self::Parameter(node) => match node {
ast::AnyParameterRef::Variadic(parameter) => parameter.into(),
ast::AnyParameterRef::NonVariadic(parameter) => parameter.into(),
Expand Down Expand Up @@ -313,13 +317,18 @@ impl MatchPatternDefinitionKind {

#[derive(Clone, Debug)]
pub struct ComprehensionDefinitionKind {
node: AstNodeRef<ast::Comprehension>,
iterable: AstNodeRef<ast::Expr>,
target: AstNodeRef<ast::ExprName>,
first: bool,
}

impl ComprehensionDefinitionKind {
pub(crate) fn node(&self) -> &ast::Comprehension {
self.node.node()
pub(crate) fn iterable(&self) -> &ast::Expr {
self.iterable.node()
}

pub(crate) fn target(&self) -> &ast::ExprName {
self.target.node()
}

pub(crate) fn is_first(&self) -> bool {
Expand Down Expand Up @@ -442,12 +451,6 @@ impl From<&ast::StmtFor> for DefinitionNodeKey {
}
}

impl From<&ast::Comprehension> for DefinitionNodeKey {
fn from(node: &ast::Comprehension) -> Self {
Self(NodeKey::from_node(node))
}
}

impl From<&ast::Parameter> for DefinitionNodeKey {
fn from(node: &ast::Parameter) -> Self {
Self(NodeKey::from_node(node))
Expand Down
84 changes: 55 additions & 29 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,8 @@ impl<'db> TypeInferenceBuilder<'db> {
}
DefinitionKind::Comprehension(comprehension) => {
self.infer_comprehension_definition(
comprehension.node(),
comprehension.iterable(),
comprehension.target(),
comprehension.is_first(),
definition,
);
Expand Down Expand Up @@ -1544,11 +1545,11 @@ impl<'db> TypeInferenceBuilder<'db> {

/// Infer the type of the `iter` expression of the first comprehension.
fn infer_first_comprehension_iter(&mut self, comprehensions: &[ast::Comprehension]) {
let mut generators_iter = comprehensions.iter();
let Some(first_generator) = generators_iter.next() else {
let mut comprehensions_iter = comprehensions.iter();
let Some(first_comprehension) = comprehensions_iter.next() else {
unreachable!("Comprehension must contain at least one generator");
};
self.infer_expression(&first_generator.iter);
self.infer_expression(&first_comprehension.iter);
}

fn infer_generator_expression(&mut self, generator: &ast::ExprGenerator) -> Type<'db> {
Expand Down Expand Up @@ -1614,9 +1615,7 @@ impl<'db> TypeInferenceBuilder<'db> {
} = generator;

self.infer_expression(elt);
for comprehension in generators {
self.infer_comprehension(comprehension);
}
self.infer_comprehensions_scope(generators);
}

fn infer_list_comprehension_expression_scope(&mut self, listcomp: &ast::ExprListComp) {
Expand All @@ -1627,9 +1626,7 @@ impl<'db> TypeInferenceBuilder<'db> {
} = listcomp;

self.infer_expression(elt);
for comprehension in generators {
self.infer_comprehension(comprehension);
}
self.infer_comprehensions_scope(generators);
}

fn infer_dict_comprehension_expression_scope(&mut self, dictcomp: &ast::ExprDictComp) {
Expand All @@ -1642,9 +1639,7 @@ impl<'db> TypeInferenceBuilder<'db> {

self.infer_expression(key);
self.infer_expression(value);
for comprehension in generators {
self.infer_comprehension(comprehension);
}
self.infer_comprehensions_scope(generators);
}

fn infer_set_comprehension_expression_scope(&mut self, setcomp: &ast::ExprSetComp) {
Expand All @@ -1655,37 +1650,68 @@ impl<'db> TypeInferenceBuilder<'db> {
} = setcomp;

self.infer_expression(elt);
for comprehension in generators {
self.infer_comprehension(comprehension);
}
self.infer_comprehensions_scope(generators);
}

fn infer_comprehension(&mut self, comprehension: &ast::Comprehension) {
self.infer_definition(comprehension);
for expr in &comprehension.ifs {
self.infer_expression(expr);
fn infer_comprehensions_scope(&mut self, comprehensions: &[ast::Comprehension]) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We've inherited some awfully confusing naming from the CPython grammar and AST :( If we have [y for x in [] for y in x], that's a single "list comprehension", whose AST node has a generators attribute containing two... Comprehensions, for x in [] and for y in x. So not only is the term "generator" overloaded here from its usual meaning, but we're also using the term "comprehension" itself to mean two entirely different things, one of which is part of the other 🤯

(To be clear, this is not a critique of this PR, just a lament for the state of naming in this part of the AST. We could fix this in our AST, but then we'd be inventing our own naming scheme that doesn't match the CPython AST, and that doesn't seem great either.)

One way to maybe clarify this naming a bit would be to switch from list_comprehension, dict_comprehension, etc all through these method names to instead use listcomp, dictcomp, etc when we are referring to the outer expression, and reserve the full word "comprehension" for the thing actually named Comprehension in the AST, which is just the inner for x in y part. I wouldn't exactly call that clear, but given what we're working with, it might at least be better? It at least matches the names used in the AST and avoids using the word "comprehension" to mean two different things.

In any case, I don't like the names infer_comprehensions_scope and infer_comprehension_scope for these two methods. The individual "Comprehension" nodes inside a list/set/dict/gen-comp are not different scopes, so the use of "scope" here is misleading.

If we go with my suggestion above to rename e.g. infer_list_comprehension_expression to infer_listcomp_expression, then I think we could name these two just infer_comprehensions and infer_comprehension.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with the sentiment, thanks for sharing that.

(To be clear, this is not a critique of this PR, just a lament for the state of naming in this part of the AST. We could fix this in our AST, but then we'd be inventing our own naming scheme that doesn't match the CPython AST, and that doesn't seem great either.)

I think, if we want, we can change the names in our AST as we've done so in the past as well (#8064, #6379, #6253, etc.). There are other recommendations here: #6183

One way to maybe clarify this naming a bit would be to switch from list_comprehension, dict_comprehension, etc all through these method names to instead use listcomp, dictcomp, etc when we are referring to the outer expression, and reserve the full word "comprehension" for the thing actually named Comprehension in the AST, which is just the inner for x in y part.

From my perspective, both listcomp and list_comprehension seems equivalent as the former seems like a short version of the latter. Is the reason for recommending listcomp because of it's presence in the CPython AST?

In any case, I don't like the names infer_comprehensions_scope and infer_comprehension_scope for these two methods. The individual "Comprehension" nodes inside a list/set/dict/gen-comp are not different scopes, so the use of "scope" here is misleading.

Makes sense. I can change this.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've not renamed the list_comprehension to listcomp as I feel like the "expression" prefix in list_comprehension_expression gives a signal that this is different than a simple "comprehension". Regardless, I don't have a strong opinion, so I'm happy to rename it if you think listcomp is better.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, mostly what my suggestion was aiming for was to hew very closely to the AST naming, in hopes that this would at least give some guideposts to readers. But I'm OK with your approach and relying on _expression to distinguish.

let mut generators_iter = comprehensions.iter();
let Some(first_generator) = generators_iter.next() else {
unreachable!("Comprehension must contain at least one generator");
};
self.infer_comprehension_scope(first_generator, true);
for generator in generators_iter {
self.infer_comprehension_scope(generator, false);
dhruvmanila marked this conversation as resolved.
Show resolved Hide resolved
}
}

fn infer_comprehension_definition(
&mut self,
comprehension: &ast::Comprehension,
is_first: bool,
definition: Definition<'db>,
) {
fn infer_comprehension_scope(&mut self, comprehension: &ast::Comprehension, is_first: bool) {
let ast::Comprehension {
range: _,
target,
iter,
ifs: _,
ifs,
is_async: _,
} = comprehension;

if !is_first {
self.infer_expression(iter);
}
// TODO(dhruvmanila): The target type should be inferred based on the iter type instead.
let target_ty = self.infer_expression(target);
// TODO more complex assignment targets
if let ast::Expr::Name(name) = target {
self.infer_definition(name);
} else {
self.infer_expression(target);
}
for expr in ifs {
self.infer_expression(expr);
}
}

fn infer_comprehension_definition(
&mut self,
iterable: &ast::Expr,
target: &ast::ExprName,
is_first: bool,
definition: Definition<'db>,
) {
if !is_first {
let expression = self.index.expression(iterable);
let result = infer_expression_types(self.db, expression);
self.extend(result);
let _iterable_ty = self
.types
.expression_ty(iterable.scoped_ast_id(self.db, self.scope));
}
// TODO(dhruvmanila): The iter type for the first comprehension is coming from the
// enclosing scope.

// TODO(dhruvmanila): The target type should be inferred based on the iter type instead,
// similar to how it's done in `infer_for_statement_definition`.
let target_ty = Type::Unknown;

self.types
.expressions
.insert(target.scoped_ast_id(self.db, self.scope), target_ty);
self.types.definitions.insert(definition, target_ty);
}

Expand Down
Loading