From 29e71cb719cfaa383bb91510f682739a1664b9f1 Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Thu, 16 Jan 2025 09:13:56 -0300 Subject: [PATCH 01/15] Solve function type earlier when elaborating method call expression --- .../src/elaborator/expressions.rs | 30 ++++++++------- compiler/noirc_frontend/src/hir_def/expr.rs | 37 ++++++++++--------- 2 files changed, 36 insertions(+), 31 deletions(-) diff --git a/compiler/noirc_frontend/src/elaborator/expressions.rs b/compiler/noirc_frontend/src/elaborator/expressions.rs index 33af075aebd..be14fb2c4ed 100644 --- a/compiler/noirc_frontend/src/elaborator/expressions.rs +++ b/compiler/noirc_frontend/src/elaborator/expressions.rs @@ -458,6 +458,20 @@ impl<'context> Elaborator<'context> { None }; + let call_span = Span::from(object_span.start()..method_name_span.end()); + let location = Location::new(call_span, self.file); + + let (function_id, function_name) = method_ref.clone().into_function_id_and_name( + object_type.clone(), + generics.clone(), + location, + self.interner, + ); + + let func_type = + self.type_check_variable(function_name.clone(), function_id, generics.clone()); + self.interner.push_expr_type(function_id, func_type.clone()); + // These arguments will be given to the desugared function call. // Compared to the method arguments, they also contain the object. let mut function_args = Vec::with_capacity(method_call.arguments.len() + 1); @@ -472,10 +486,7 @@ impl<'context> Elaborator<'context> { function_args.push((typ, arg, span)); } - let call_span = Span::from(object_span.start()..method_name_span.end()); - let location = Location::new(call_span, self.file); let method = method_call.method_name; - let turbofish_generics = generics.clone(); let is_macro_call = method_call.is_macro_call; let method_call = HirMethodCallExpression { method, object, arguments, location, generics }; @@ -485,18 +496,9 @@ impl<'context> Elaborator<'context> { // Desugar the method call into a normal, resolved function call // so that the backend doesn't need to worry about methods // TODO: update object_type here? - let ((function_id, function_name), function_call) = method_call.into_function_call( - method_ref, - object_type, - is_macro_call, - location, - self.interner, - ); - - let func_type = - self.type_check_variable(function_name, function_id, turbofish_generics); - self.interner.push_expr_type(function_id, func_type.clone()); + let function_call = + method_call.into_function_call(function_id, is_macro_call, location); self.interner .add_function_reference(func_id, Location::new(method_name_span, self.file)); diff --git a/compiler/noirc_frontend/src/hir_def/expr.rs b/compiler/noirc_frontend/src/hir_def/expr.rs index 9b3bf4962bb..5ac228a56d6 100644 --- a/compiler/noirc_frontend/src/hir_def/expr.rs +++ b/compiler/noirc_frontend/src/hir_def/expr.rs @@ -225,24 +225,15 @@ impl HirMethodReference { } } } -} -impl HirMethodCallExpression { - /// Converts a method call into a function call - /// - /// Returns ((func_var_id, func_var), call_expr) - pub fn into_function_call( - mut self, - method: HirMethodReference, + pub fn into_function_id_and_name( + self, object_type: Type, - is_macro_call: bool, + generics: Option>, location: Location, interner: &mut NodeInterner, - ) -> ((ExprId, HirIdent), HirCallExpression) { - let mut arguments = vec![self.object]; - arguments.append(&mut self.arguments); - - let (id, impl_kind) = match method { + ) -> (ExprId, HirIdent) { + let (id, impl_kind) = match self { HirMethodReference::FuncId(func_id) => { (interner.function_definition_id(func_id), ImplKind::NotATraitMethod) } @@ -261,10 +252,22 @@ impl HirMethodCallExpression { } }; let func_var = HirIdent { location, id, impl_kind }; - let func = interner.push_expr(HirExpression::Ident(func_var.clone(), self.generics)); + let func = interner.push_expr(HirExpression::Ident(func_var.clone(), generics)); interner.push_expr_location(func, location.span, location.file); - let expr = HirCallExpression { func, arguments, location, is_macro_call }; - ((func, func_var), expr) + (func, func_var) + } +} + +impl HirMethodCallExpression { + pub fn into_function_call( + mut self, + func: ExprId, + is_macro_call: bool, + location: Location, + ) -> HirCallExpression { + let mut arguments = vec![self.object]; + arguments.append(&mut self.arguments); + HirCallExpression { func, arguments, location, is_macro_call } } } From 6894d20251449ead6371b0ed4378d6bcf6e504c3 Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Thu, 16 Jan 2025 10:20:25 -0300 Subject: [PATCH 02/15] feat: lambda parameter hints for method calls by unifying self type --- .../src/elaborator/expressions.rs | 72 ++++++++++++++++--- compiler/noirc_frontend/src/tests.rs | 31 ++++++++ 2 files changed, 93 insertions(+), 10 deletions(-) diff --git a/compiler/noirc_frontend/src/elaborator/expressions.rs b/compiler/noirc_frontend/src/elaborator/expressions.rs index be14fb2c4ed..72713f5ae9d 100644 --- a/compiler/noirc_frontend/src/elaborator/expressions.rs +++ b/compiler/noirc_frontend/src/elaborator/expressions.rs @@ -52,7 +52,7 @@ impl<'context> Elaborator<'context> { ExpressionKind::If(if_) => self.elaborate_if(*if_), ExpressionKind::Variable(variable) => return self.elaborate_variable(variable), ExpressionKind::Tuple(tuple) => self.elaborate_tuple(tuple), - ExpressionKind::Lambda(lambda) => self.elaborate_lambda(*lambda), + ExpressionKind::Lambda(lambda) => self.elaborate_lambda(*lambda, None), ExpressionKind::Parenthesized(expr) => return self.elaborate_expression(*expr), ExpressionKind::Quote(quote) => self.elaborate_quote(quote, expr.span), ExpressionKind::Comptime(comptime, _) => { @@ -472,6 +472,17 @@ impl<'context> Elaborator<'context> { self.type_check_variable(function_name.clone(), function_id, generics.clone()); self.interner.push_expr_type(function_id, func_type.clone()); + // Try to unify the object type with the first argument of the function. + // The reason to do this is that many methods that take a lambda will yield `self` or part of `self` + // as a parameter. By unifying `self` with the first argument we'll potentially get more + // concrete types in the arguments that are function types, which will later be passed as + // lambda parameter hints. + if let Type::Function(args, _, _, _) = &func_type { + if !args.is_empty() { + let _ = args[0].unify(&object_type); + } + } + // These arguments will be given to the desugared function call. // Compared to the method arguments, they also contain the object. let mut function_args = Vec::with_capacity(method_call.arguments.len() + 1); @@ -479,9 +490,28 @@ impl<'context> Elaborator<'context> { function_args.push((object_type.clone(), object, object_span)); - for arg in method_call.arguments { + for (arg_index, arg) in method_call.arguments.into_iter().enumerate() { let span = arg.span; - let (arg, typ) = self.elaborate_expression(arg); + let (arg, typ) = if let ExpressionKind::Lambda(lambda) = arg.kind { + let type_hint = if let Type::Function(func_args, _, _, _) = &func_type { + if let Some(Type::Function(func_args, _, _, _)) = + func_args.get(arg_index + 1) + { + Some(func_args) + } else { + None + } + } else { + None + }; + let (hir_expr, typ) = self.elaborate_lambda(*lambda, type_hint); + let id = self.interner.push_expr(hir_expr); + self.interner.push_expr_location(id, span, self.file); + self.interner.push_expr_type(id, typ.clone()); + (id, typ) + } else { + self.elaborate_expression(arg) + }; arguments.push(arg); function_args.push((typ, arg, span)); } @@ -848,19 +878,41 @@ impl<'context> Elaborator<'context> { (HirExpression::Tuple(element_ids), Type::Tuple(element_types)) } - fn elaborate_lambda(&mut self, lambda: Lambda) -> (HirExpression, Type) { + /// For elaborating a lambda we might get `parameters_type_hints`. These come from a potential + /// call that has this lambda as the argument. + /// The parameter type hints will be the types of the function type corresponding to the lambda argument. + fn elaborate_lambda( + &mut self, + lambda: Lambda, + parameters_type_hints: Option<&Vec>, + ) -> (HirExpression, Type) { self.push_scope(); let scope_index = self.scopes.current_scope_index(); self.lambda_stack.push(LambdaContext { captures: Vec::new(), scope_index }); let mut arg_types = Vec::with_capacity(lambda.parameters.len()); - let parameters = vecmap(lambda.parameters, |(pattern, typ)| { - let parameter = DefinitionKind::Local(None); - let typ = self.resolve_inferred_type(typ); - arg_types.push(typ.clone()); - (self.elaborate_pattern(pattern, typ.clone(), parameter, true), typ) - }); + let parameters: Vec<_> = lambda + .parameters + .into_iter() + .enumerate() + .map(|(index, (pattern, typ))| { + let parameter = DefinitionKind::Local(None); + let typ = self.resolve_inferred_type(typ); + + // If there's a parameter type hint, use it to unify the argument type + if let Some(parameter_type_hint) = + parameters_type_hints.and_then(|hints| hints.get(index)) + { + // We don't error here because eventually the lambda type will be checked against + // the call that contains it, which would then produce an error if this didn't unify. + let _ = typ.unify(parameter_type_hint); + } + + arg_types.push(typ.clone()); + (self.elaborate_pattern(pattern, typ.clone(), parameter, true), typ) + }) + .collect(); let return_type = self.resolve_inferred_type(lambda.return_type); let body_span = lambda.body.span; diff --git a/compiler/noirc_frontend/src/tests.rs b/compiler/noirc_frontend/src/tests.rs index 637b15e7197..33bf8df126e 100644 --- a/compiler/noirc_frontend/src/tests.rs +++ b/compiler/noirc_frontend/src/tests.rs @@ -3978,3 +3978,34 @@ fn checks_visibility_of_trait_related_to_trait_impl_on_method_call() { "#; assert_no_errors(src); } + +#[test] +fn infers_lambda_argument_from_call_function_type() { + let src = r#" + struct Foo { + value: Field, + } + + impl Foo { + fn foo(self) -> Field { + self.value + } + } + + struct Box { + value: T, + } + + impl Box { + fn map(self, f: fn(T) -> U) -> Box { + Box { value: f(self.value) } + } + } + + fn main() { + let box = Box { value: Foo { value: 1 } }; + let _ = box.map(|foo| foo.foo()); + } + "#; + assert_no_errors(src); +} From b543252186aaaf25241088781b6f5e2aa181bd7a Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Thu, 16 Jan 2025 10:28:54 -0300 Subject: [PATCH 03/15] Simplify lambda parameter types in the stdlib --- noir_stdlib/src/array/mod.nr | 2 +- noir_stdlib/src/meta/expr.nr | 101 +++++++++++++---------------------- noir_stdlib/src/meta/mod.nr | 5 +- 3 files changed, 39 insertions(+), 69 deletions(-) diff --git a/noir_stdlib/src/array/mod.nr b/noir_stdlib/src/array/mod.nr index 47dc3ca7bb9..85cc0580aae 100644 --- a/noir_stdlib/src/array/mod.nr +++ b/noir_stdlib/src/array/mod.nr @@ -157,7 +157,7 @@ where /// } /// ``` pub fn sort(self) -> Self { - self.sort_via(|a: T, b: T| a <= b) + self.sort_via(|a, b| a <= b) } } diff --git a/noir_stdlib/src/meta/expr.nr b/noir_stdlib/src/meta/expr.nr index 7538b26dc44..a1663135c20 100644 --- a/noir_stdlib/src/meta/expr.nr +++ b/noir_stdlib/src/meta/expr.nr @@ -285,33 +285,31 @@ impl Expr { } comptime fn modify_array(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { - expr.as_array().map(|exprs: [Expr]| { + expr.as_array().map(|exprs| { let exprs = modify_expressions(exprs, f); new_array(exprs) }) } comptime fn modify_assert(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { - expr.as_assert().map(|expr: (Expr, Option)| { - let (predicate, msg) = expr; + expr.as_assert().map(|(predicate, msg)| { let predicate = predicate.modify(f); - let msg = msg.map(|msg: Expr| msg.modify(f)); + let msg = msg.map(|msg| msg.modify(f)); new_assert(predicate, msg) }) } comptime fn modify_assert_eq(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { - expr.as_assert_eq().map(|expr: (Expr, Expr, Option)| { - let (lhs, rhs, msg) = expr; + expr.as_assert_eq().map(|(lhs, rhs, msg)| { let lhs = lhs.modify(f); let rhs = rhs.modify(f); - let msg = msg.map(|msg: Expr| msg.modify(f)); + let msg = msg.map(|msg| msg.modify(f)); new_assert_eq(lhs, rhs, msg) }) } comptime fn modify_assign(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { - expr.as_assign().map(|expr: (Expr, Expr)| { + expr.as_assign().map(|expr| { let (lhs, rhs) = expr; let lhs = lhs.modify(f); let rhs = rhs.modify(f); @@ -320,8 +318,7 @@ comptime fn modify_assign(expr: Expr, f: fn[Env](Expr) -> Option) -> } comptime fn modify_binary_op(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { - expr.as_binary_op().map(|expr: (Expr, BinaryOp, Expr)| { - let (lhs, op, rhs) = expr; + expr.as_binary_op().map(|(lhs, op, rhs)| { let lhs = lhs.modify(f); let rhs = rhs.modify(f); new_binary_op(lhs, op, rhs) @@ -329,34 +326,29 @@ comptime fn modify_binary_op(expr: Expr, f: fn[Env](Expr) -> Option) } comptime fn modify_block(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { - expr.as_block().map(|exprs: [Expr]| { + expr.as_block().map(|exprs| { let exprs = modify_expressions(exprs, f); new_block(exprs) }) } comptime fn modify_cast(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { - expr.as_cast().map(|expr: (Expr, UnresolvedType)| { - let (expr, typ) = expr; + expr.as_cast().map(|(expr, typ)| { let expr = expr.modify(f); new_cast(expr, typ) }) } comptime fn modify_comptime(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { - expr.as_comptime().map(|exprs: [Expr]| { - let exprs = exprs.map(|expr: Expr| expr.modify(f)); + expr.as_comptime().map(|exprs| { + let exprs = exprs.map(|expr| expr.modify(f)); new_comptime(exprs) }) } comptime fn modify_constructor(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { - expr.as_constructor().map(|expr: (UnresolvedType, [(Quoted, Expr)])| { - let (typ, fields) = expr; - let fields = fields.map(|field: (Quoted, Expr)| { - let (name, value) = field; - (name, value.modify(f)) - }); + expr.as_constructor().map(|(typ, fields)| { + let fields = fields.map(|(name, value)| (name, value.modify(f))); new_constructor(typ, fields) }) } @@ -365,27 +357,24 @@ comptime fn modify_function_call( expr: Expr, f: fn[Env](Expr) -> Option, ) -> Option { - expr.as_function_call().map(|expr: (Expr, [Expr])| { - let (function, arguments) = expr; + expr.as_function_call().map(|(function, arguments)| { let function = function.modify(f); - let arguments = arguments.map(|arg: Expr| arg.modify(f)); + let arguments = arguments.map(|arg| arg.modify(f)); new_function_call(function, arguments) }) } comptime fn modify_if(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { - expr.as_if().map(|expr: (Expr, Expr, Option)| { - let (condition, consequence, alternative) = expr; + expr.as_if().map(|(condition, consequence, alternative)| { let condition = condition.modify(f); let consequence = consequence.modify(f); - let alternative = alternative.map(|alternative: Expr| alternative.modify(f)); + let alternative = alternative.map(|alternative| alternative.modify(f)); new_if(condition, consequence, alternative) }) } comptime fn modify_index(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { - expr.as_index().map(|expr: (Expr, Expr)| { - let (object, index) = expr; + expr.as_index().map(|(object, index)| { let object = object.modify(f); let index = index.modify(f); new_index(object, index) @@ -393,8 +382,7 @@ comptime fn modify_index(expr: Expr, f: fn[Env](Expr) -> Option) -> O } comptime fn modify_for(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { - expr.as_for().map(|expr: (Quoted, Expr, Expr)| { - let (identifier, array, body) = expr; + expr.as_for().map(|(identifier, array, body)| { let array = array.modify(f); let body = body.modify(f); new_for(identifier, array, body) @@ -402,8 +390,7 @@ comptime fn modify_for(expr: Expr, f: fn[Env](Expr) -> Option) -> Opt } comptime fn modify_for_range(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { - expr.as_for_range().map(|expr: (Quoted, Expr, Expr, Expr)| { - let (identifier, from, to, body) = expr; + expr.as_for_range().map(|(identifier, from, to, body)| { let from = from.modify(f); let to = to.modify(f); let body = body.modify(f); @@ -412,18 +399,15 @@ comptime fn modify_for_range(expr: Expr, f: fn[Env](Expr) -> Option) } comptime fn modify_lambda(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { - expr.as_lambda().map(|expr: ([(Expr, Option)], Option, Expr)| { - let (params, return_type, body) = expr; - let params = - params.map(|param: (Expr, Option)| (param.0.modify(f), param.1)); + expr.as_lambda().map(|(params, return_type, body)| { + let params = params.map(|(name, typ)| (name.modify(f), typ)); let body = body.modify(f); new_lambda(params, return_type, body) }) } comptime fn modify_let(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { - expr.as_let().map(|expr: (Expr, Option, Expr)| { - let (pattern, typ, expr) = expr; + expr.as_let().map(|(pattern, typ, expr)| { let pattern = pattern.modify(f); let expr = expr.modify(f); new_let(pattern, typ, expr) @@ -434,18 +418,16 @@ comptime fn modify_member_access( expr: Expr, f: fn[Env](Expr) -> Option, ) -> Option { - expr.as_member_access().map(|expr: (Expr, Quoted)| { - let (object, name) = expr; + expr.as_member_access().map(|(object, name)| { let object = object.modify(f); new_member_access(object, name) }) } comptime fn modify_method_call(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { - expr.as_method_call().map(|expr: (Expr, Quoted, [UnresolvedType], [Expr])| { - let (object, name, generics, arguments) = expr; + expr.as_method_call().map(|(object, name, generics, arguments)| { let object = object.modify(f); - let arguments = arguments.map(|arg: Expr| arg.modify(f)); + let arguments = arguments.map(|arg| arg.modify(f)); new_method_call(object, name, generics, arguments) }) } @@ -454,8 +436,7 @@ comptime fn modify_repeated_element_array( expr: Expr, f: fn[Env](Expr) -> Option, ) -> Option { - expr.as_repeated_element_array().map(|expr: (Expr, Expr)| { - let (expr, length) = expr; + expr.as_repeated_element_array().map(|(expr, length)| { let expr = expr.modify(f); let length = length.modify(f); new_repeated_element_array(expr, length) @@ -466,8 +447,7 @@ comptime fn modify_repeated_element_slice( expr: Expr, f: fn[Env](Expr) -> Option, ) -> Option { - expr.as_repeated_element_slice().map(|expr: (Expr, Expr)| { - let (expr, length) = expr; + expr.as_repeated_element_slice().map(|(expr, length)| { let expr = expr.modify(f); let length = length.modify(f); new_repeated_element_slice(expr, length) @@ -475,36 +455,35 @@ comptime fn modify_repeated_element_slice( } comptime fn modify_slice(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { - expr.as_slice().map(|exprs: [Expr]| { + expr.as_slice().map(|exprs| { let exprs = modify_expressions(exprs, f); new_slice(exprs) }) } comptime fn modify_tuple(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { - expr.as_tuple().map(|exprs: [Expr]| { + expr.as_tuple().map(|exprs| { let exprs = modify_expressions(exprs, f); new_tuple(exprs) }) } comptime fn modify_unary_op(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { - expr.as_unary_op().map(|expr: (UnaryOp, Expr)| { - let (op, rhs) = expr; + expr.as_unary_op().map(|(op, rhs)| { let rhs = rhs.modify(f); new_unary_op(op, rhs) }) } comptime fn modify_unsafe(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { - expr.as_unsafe().map(|exprs: [Expr]| { - let exprs = exprs.map(|expr: Expr| expr.modify(f)); + expr.as_unsafe().map(|exprs| { + let exprs = exprs.map(|expr| expr.modify(f)); new_unsafe(exprs) }) } comptime fn modify_expressions(exprs: [Expr], f: fn[Env](Expr) -> Option) -> [Expr] { - exprs.map(|expr: Expr| expr.modify(f)) + exprs.map(|expr| expr.modify(f)) } comptime fn new_array(exprs: [Expr]) -> Expr { @@ -554,12 +533,7 @@ comptime fn new_comptime(exprs: [Expr]) -> Expr { } comptime fn new_constructor(typ: UnresolvedType, fields: [(Quoted, Expr)]) -> Expr { - let fields = fields - .map(|field: (Quoted, Expr)| { - let (name, value) = field; - quote { $name: $value } - }) - .join(quote { , }); + let fields = fields.map(|(name, value)| quote { $name: $value }).join(quote { , }); quote { $typ { $fields }}.as_expr().unwrap() } @@ -590,8 +564,7 @@ comptime fn new_lambda( body: Expr, ) -> Expr { let params = params - .map(|param: (Expr, Option)| { - let (name, typ) = param; + .map(|(name, typ)| { if typ.is_some() { let typ = typ.unwrap(); quote { $name: $typ } @@ -678,5 +651,5 @@ comptime fn new_unsafe(exprs: [Expr]) -> Expr { } comptime fn join_expressions(exprs: [Expr], separator: Quoted) -> Quoted { - exprs.map(|expr: Expr| expr.quoted()).join(separator) + exprs.map(|expr| expr.quoted()).join(separator) } diff --git a/noir_stdlib/src/meta/mod.nr b/noir_stdlib/src/meta/mod.nr index 21c1b14cc96..046e62d661c 100644 --- a/noir_stdlib/src/meta/mod.nr +++ b/noir_stdlib/src/meta/mod.nr @@ -101,10 +101,7 @@ pub comptime fn make_trait_impl( let where_clause = s.generics().map(|name| quote { $name: $trait_name }).join(quote {,}); // `for_each_field(field1) $join_fields_with for_each_field(field2) $join_fields_with ...` - let fields = s.fields_as_written().map(|f: (Quoted, Type)| { - let name = f.0; - for_each_field(name) - }); + let fields = s.fields_as_written().map(|(name, _)| for_each_field(name)); let body = body(fields.join(join_fields_with)); quote { From d248b3da7db201e8cab5f8b33e0875ccfd170c26 Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Thu, 16 Jan 2025 10:37:10 -0300 Subject: [PATCH 04/15] Simplify some lambdas in test programs --- .../inject_context_attribute/src/main.nr | 23 ++++++++----------- .../trait_generics/src/main.nr | 2 +- .../unquote_struct/src/main.nr | 9 +------- .../comptime_expr/src/main.nr | 3 +-- 4 files changed, 13 insertions(+), 24 deletions(-) diff --git a/test_programs/compile_success_empty/inject_context_attribute/src/main.nr b/test_programs/compile_success_empty/inject_context_attribute/src/main.nr index 963d4cea969..e682ea34b23 100644 --- a/test_programs/compile_success_empty/inject_context_attribute/src/main.nr +++ b/test_programs/compile_success_empty/inject_context_attribute/src/main.nr @@ -40,19 +40,16 @@ comptime fn inject_context(f: FunctionDefinition) { } comptime fn mapping_function(expr: Expr, f: FunctionDefinition) -> Option { - expr.as_function_call().and_then(|func_call: (Expr, [Expr])| { - let (name, arguments) = func_call; - name.resolve(Option::some(f)).as_function_definition().and_then( - |function_definition: FunctionDefinition| { - if function_definition.has_named_attribute("inject_context") { - let arguments = arguments.push_front(quote { _context }.as_expr().unwrap()); - let arguments = arguments.map(|arg: Expr| arg.quoted()).join(quote { , }); - Option::some(quote { $name($arguments) }.as_expr().unwrap()) - } else { - Option::none() - } - }, - ) + expr.as_function_call().and_then(|(name, arguments)| { + name.resolve(Option::some(f)).as_function_definition().and_then(|function_definition| { + if function_definition.has_named_attribute("inject_context") { + let arguments = arguments.push_front(quote { _context }.as_expr().unwrap()); + let arguments = arguments.map(|arg| arg.quoted()).join(quote { , }); + Option::some(quote { $name($arguments) }.as_expr().unwrap()) + } else { + Option::none() + } + }) }) } diff --git a/test_programs/compile_success_empty/trait_generics/src/main.nr b/test_programs/compile_success_empty/trait_generics/src/main.nr index 08302ded68c..e8b57b6fe6f 100644 --- a/test_programs/compile_success_empty/trait_generics/src/main.nr +++ b/test_programs/compile_success_empty/trait_generics/src/main.nr @@ -24,7 +24,7 @@ where T: MyInto, { fn into(self) -> [U; N] { - self.map(|x: T| x.into()) + self.map(|x| x.into()) } } diff --git a/test_programs/compile_success_empty/unquote_struct/src/main.nr b/test_programs/compile_success_empty/unquote_struct/src/main.nr index d4ab275858c..12c683a94a8 100644 --- a/test_programs/compile_success_empty/unquote_struct/src/main.nr +++ b/test_programs/compile_success_empty/unquote_struct/src/main.nr @@ -10,14 +10,7 @@ fn foo(x: Field, y: u32) -> u32 { // Given a function, wrap its parameters in a struct definition comptime fn output_struct(f: FunctionDefinition) -> Quoted { - let fields = f - .parameters() - .map(|param: (Quoted, Type)| { - let name = param.0; - let typ = param.1; - quote { $name: $typ, } - }) - .join(quote {}); + let fields = f.parameters().map(|(name, typ)| quote { $name: $typ, }).join(quote {}); quote { struct Foo { $fields } diff --git a/test_programs/noir_test_success/comptime_expr/src/main.nr b/test_programs/noir_test_success/comptime_expr/src/main.nr index 25910685e87..6efbc212cbe 100644 --- a/test_programs/noir_test_success/comptime_expr/src/main.nr +++ b/test_programs/noir_test_success/comptime_expr/src/main.nr @@ -761,8 +761,7 @@ mod tests { } comptime fn times_two(expr: Expr) -> Option { - expr.as_integer().and_then(|integer: (Field, bool)| { - let (value, _) = integer; + expr.as_integer().and_then(|(value, _)| { let value = value * 2; quote { $value }.as_expr() }) From 86754e29c136383bed8f922fd830045f521f1eb5 Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Thu, 16 Jan 2025 10:51:39 -0300 Subject: [PATCH 05/15] Do the same thing with regular calls --- .../src/elaborator/expressions.rs | 78 ++++++++++++------- compiler/noirc_frontend/src/tests.rs | 20 ++++- 2 files changed, 67 insertions(+), 31 deletions(-) diff --git a/compiler/noirc_frontend/src/elaborator/expressions.rs b/compiler/noirc_frontend/src/elaborator/expressions.rs index 72713f5ae9d..fee0ce720cf 100644 --- a/compiler/noirc_frontend/src/elaborator/expressions.rs +++ b/compiler/noirc_frontend/src/elaborator/expressions.rs @@ -389,18 +389,25 @@ impl<'context> Elaborator<'context> { let (func, func_type) = self.elaborate_expression(*call.func); let mut arguments = Vec::with_capacity(call.arguments.len()); - let args = vecmap(call.arguments, |arg| { - let span = arg.span; + let args: Vec<_> = call + .arguments + .into_iter() + .enumerate() + .map(|(arg_index, arg)| { + let span = arg.span; - let (arg, typ) = if call.is_macro_call { - self.elaborate_in_comptime_context(|this| this.elaborate_expression(arg)) - } else { - self.elaborate_expression(arg) - }; + let (arg, typ) = if call.is_macro_call { + self.elaborate_in_comptime_context(|this| { + this.elaborate_call_argument_expression(arg, arg_index, &func_type) + }) + } else { + self.elaborate_call_argument_expression(arg, arg_index, &func_type) + }; - arguments.push(arg); - (typ, arg, span) - }); + arguments.push(arg); + (typ, arg, span) + }) + .collect(); // Avoid cloning arguments unless this is a macro call let mut comptime_args = Vec::new(); @@ -492,26 +499,8 @@ impl<'context> Elaborator<'context> { for (arg_index, arg) in method_call.arguments.into_iter().enumerate() { let span = arg.span; - let (arg, typ) = if let ExpressionKind::Lambda(lambda) = arg.kind { - let type_hint = if let Type::Function(func_args, _, _, _) = &func_type { - if let Some(Type::Function(func_args, _, _, _)) = - func_args.get(arg_index + 1) - { - Some(func_args) - } else { - None - } - } else { - None - }; - let (hir_expr, typ) = self.elaborate_lambda(*lambda, type_hint); - let id = self.interner.push_expr(hir_expr); - self.interner.push_expr_location(id, span, self.file); - self.interner.push_expr_type(id, typ.clone()); - (id, typ) - } else { - self.elaborate_expression(arg) - }; + let (arg, typ) = + self.elaborate_call_argument_expression(arg, arg_index + 1, &func_type); arguments.push(arg); function_args.push((typ, arg, span)); } @@ -552,6 +541,35 @@ impl<'context> Elaborator<'context> { } } + /// Elaborates an expression taking into account that it's a call argument in a function + /// that has the given type, and `arg_index` is the index of that argument in that function type. + fn elaborate_call_argument_expression( + &mut self, + arg: Expression, + arg_index: usize, + func_type: &Type, + ) -> (ExprId, Type) { + let ExpressionKind::Lambda(lambda) = arg.kind else { + return self.elaborate_expression(arg); + }; + + let span = arg.span; + let type_hint = if let Type::Function(func_args, _, _, _) = func_type { + if let Some(Type::Function(func_args, _, _, _)) = func_args.get(arg_index) { + Some(func_args) + } else { + None + } + } else { + None + }; + let (hir_expr, typ) = self.elaborate_lambda(*lambda, type_hint); + let id = self.interner.push_expr(hir_expr); + self.interner.push_expr_location(id, span, self.file); + self.interner.push_expr_type(id, typ.clone()); + (id, typ) + } + fn check_method_call_visibility(&mut self, func_id: FuncId, object_type: &Type, name: &Ident) { if !method_call_is_visible( object_type, diff --git a/compiler/noirc_frontend/src/tests.rs b/compiler/noirc_frontend/src/tests.rs index 33bf8df126e..aafd5135ae2 100644 --- a/compiler/noirc_frontend/src/tests.rs +++ b/compiler/noirc_frontend/src/tests.rs @@ -3980,7 +3980,7 @@ fn checks_visibility_of_trait_related_to_trait_impl_on_method_call() { } #[test] -fn infers_lambda_argument_from_call_function_type() { +fn infers_lambda_argument_from_method_call_function_type() { let src = r#" struct Foo { value: Field, @@ -4009,3 +4009,21 @@ fn infers_lambda_argument_from_call_function_type() { "#; assert_no_errors(src); } + +#[test] +fn infers_lambda_argument_from_call_function_type() { + let src = r#" + struct Foo { + value: Field, + } + + fn call(f: fn(Foo) -> Field) -> Field { + f(Foo { value: 1 }) + } + + fn main() { + let _ = call(|foo| foo.value); + } + "#; + assert_no_errors(src); +} From 0eabf96260936c6060f6c1e33a84ed24e3fef846 Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Thu, 16 Jan 2025 11:09:26 -0300 Subject: [PATCH 06/15] Eagerly try to unify call arguments so lambdas work better --- compiler/noirc_frontend/src/ast/expression.rs | 4 ++ .../src/elaborator/expressions.rs | 44 +++++++++++++++---- compiler/noirc_frontend/src/tests.rs | 22 +++++++++- 3 files changed, 60 insertions(+), 10 deletions(-) diff --git a/compiler/noirc_frontend/src/ast/expression.rs b/compiler/noirc_frontend/src/ast/expression.rs index 9d521545e7a..c052527442e 100644 --- a/compiler/noirc_frontend/src/ast/expression.rs +++ b/compiler/noirc_frontend/src/ast/expression.rs @@ -223,6 +223,10 @@ impl ExpressionKind { struct_type: None, })) } + + pub fn is_lambda(&self) -> bool { + matches!(self, ExpressionKind::Lambda(..)) + } } impl Recoverable for ExpressionKind { diff --git a/compiler/noirc_frontend/src/elaborator/expressions.rs b/compiler/noirc_frontend/src/elaborator/expressions.rs index fee0ce720cf..a74cd6be70b 100644 --- a/compiler/noirc_frontend/src/elaborator/expressions.rs +++ b/compiler/noirc_frontend/src/elaborator/expressions.rs @@ -388,6 +388,8 @@ impl<'context> Elaborator<'context> { fn elaborate_call(&mut self, call: CallExpression, span: Span) -> (HirExpression, Type) { let (func, func_type) = self.elaborate_expression(*call.func); + let any_argument_is_lambda = call.arguments.iter().any(|arg| arg.kind.is_lambda()); + let mut arguments = Vec::with_capacity(call.arguments.len()); let args: Vec<_> = call .arguments @@ -404,6 +406,16 @@ impl<'context> Elaborator<'context> { self.elaborate_call_argument_expression(arg, arg_index, &func_type) }; + if any_argument_is_lambda { + // Try to unify this argument type against the function's argument type + // so that a potential lambda following this argument can have more concrete types. + if let Type::Function(func_args, _, _, _) = &func_type { + if let Some(func_arg_type) = func_args.get(arg_index) { + let _ = func_arg_type.unify(&typ); + } + } + } + arguments.push(arg); (typ, arg, span) }) @@ -479,14 +491,19 @@ impl<'context> Elaborator<'context> { self.type_check_variable(function_name.clone(), function_id, generics.clone()); self.interner.push_expr_type(function_id, func_type.clone()); - // Try to unify the object type with the first argument of the function. - // The reason to do this is that many methods that take a lambda will yield `self` or part of `self` - // as a parameter. By unifying `self` with the first argument we'll potentially get more - // concrete types in the arguments that are function types, which will later be passed as - // lambda parameter hints. - if let Type::Function(args, _, _, _) = &func_type { - if !args.is_empty() { - let _ = args[0].unify(&object_type); + let any_argument_is_lambda = + method_call.arguments.iter().any(|arg| arg.kind.is_lambda()); + + if any_argument_is_lambda { + // Try to unify the object type with the first argument of the function. + // The reason to do this is that many methods that take a lambda will yield `self` or part of `self` + // as a parameter. By unifying `self` with the first argument we'll potentially get more + // concrete types in the arguments that are function types, which will later be passed as + // lambda parameter hints. + if let Type::Function(args, _, _, _) = &func_type { + if !args.is_empty() { + let _ = args[0].unify(&object_type); + } } } @@ -501,6 +518,17 @@ impl<'context> Elaborator<'context> { let span = arg.span; let (arg, typ) = self.elaborate_call_argument_expression(arg, arg_index + 1, &func_type); + + if any_argument_is_lambda { + // Try to unify this argument type against the function's argument type + // so that a potential lambda following this argument can have more concrete types. + if let Type::Function(func_args, _, _, _) = &func_type { + if let Some(func_arg_type) = func_args.get(arg_index + 1) { + let _ = func_arg_type.unify(&typ); + } + } + } + arguments.push(arg); function_args.push((typ, arg, span)); } diff --git a/compiler/noirc_frontend/src/tests.rs b/compiler/noirc_frontend/src/tests.rs index aafd5135ae2..5eaeb43ec8f 100644 --- a/compiler/noirc_frontend/src/tests.rs +++ b/compiler/noirc_frontend/src/tests.rs @@ -3015,13 +3015,13 @@ fn do_not_eagerly_error_on_cast_on_type_variable() { #[test] fn error_on_cast_over_type_variable() { let src = r#" - pub fn foo(x: T, f: fn(T) -> U) -> U { + pub fn foo(f: fn(T) -> U, x: T, ) -> U { f(x) } fn main() { let x = "a"; - let _: Field = foo(x, |x| x as Field); + let _: Field = foo(|x| x as Field, x); } "#; @@ -4027,3 +4027,21 @@ fn infers_lambda_argument_from_call_function_type() { "#; assert_no_errors(src); } + +#[test] +fn infers_lambda_argument_from_call_function_type_in_generic_call() { + let src = r#" + struct Foo { + value: Field, + } + + fn call(t: T, f: fn(T) -> Field) -> Field { + f(t) + } + + fn main() { + let _ = call(Foo { value: 1 }, |foo| foo.value); + } + "#; + assert_no_errors(src); +} From 5b1e0132aac57fe0f7c6f06c21db78b4484b38b5 Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Thu, 16 Jan 2025 11:29:26 -0300 Subject: [PATCH 07/15] noir-edwards compiles again --- .../{.failures.jsonl.does_not_compile => .failures.jsonl} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename .github/critical_libraries_status/noir-lang/noir-edwards/{.failures.jsonl.does_not_compile => .failures.jsonl} (100%) diff --git a/.github/critical_libraries_status/noir-lang/noir-edwards/.failures.jsonl.does_not_compile b/.github/critical_libraries_status/noir-lang/noir-edwards/.failures.jsonl similarity index 100% rename from .github/critical_libraries_status/noir-lang/noir-edwards/.failures.jsonl.does_not_compile rename to .github/critical_libraries_status/noir-lang/noir-edwards/.failures.jsonl From 79f407668eb75a48e4d3cbda45a17b3c4037b4fc Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Thu, 16 Jan 2025 11:43:44 -0300 Subject: [PATCH 08/15] Only unify parameters without a type annotation --- .../src/elaborator/expressions.rs | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/compiler/noirc_frontend/src/elaborator/expressions.rs b/compiler/noirc_frontend/src/elaborator/expressions.rs index a74cd6be70b..ed7ed66119c 100644 --- a/compiler/noirc_frontend/src/elaborator/expressions.rs +++ b/compiler/noirc_frontend/src/elaborator/expressions.rs @@ -944,15 +944,18 @@ impl<'context> Elaborator<'context> { .enumerate() .map(|(index, (pattern, typ))| { let parameter = DefinitionKind::Local(None); + let is_unspecified = matches!(typ.typ, UnresolvedTypeData::Unspecified); let typ = self.resolve_inferred_type(typ); - // If there's a parameter type hint, use it to unify the argument type - if let Some(parameter_type_hint) = - parameters_type_hints.and_then(|hints| hints.get(index)) - { - // We don't error here because eventually the lambda type will be checked against - // the call that contains it, which would then produce an error if this didn't unify. - let _ = typ.unify(parameter_type_hint); + if is_unspecified { + // If there's a parameter type hint, use it to unify the argument type + if let Some(parameter_type_hint) = + parameters_type_hints.and_then(|hints| hints.get(index)) + { + // We don't error here because eventually the lambda type will be checked against + // the call that contains it, which would then produce an error if this didn't unify. + let _ = typ.unify(parameter_type_hint); + } } arg_types.push(typ.clone()); From 2cffe19f6f098d7b23c81c2dcc1e0340033292e5 Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Thu, 16 Jan 2025 11:50:30 -0300 Subject: [PATCH 09/15] Only infer if lambdas don't have types --- compiler/noirc_frontend/src/ast/expression.rs | 8 +++- compiler/noirc_frontend/src/ast/mod.rs | 4 ++ .../src/elaborator/expressions.rs | 42 ++++++++++++++----- 3 files changed, 42 insertions(+), 12 deletions(-) diff --git a/compiler/noirc_frontend/src/ast/expression.rs b/compiler/noirc_frontend/src/ast/expression.rs index c052527442e..ca6090a8776 100644 --- a/compiler/noirc_frontend/src/ast/expression.rs +++ b/compiler/noirc_frontend/src/ast/expression.rs @@ -224,8 +224,12 @@ impl ExpressionKind { })) } - pub fn is_lambda(&self) -> bool { - matches!(self, ExpressionKind::Lambda(..)) + pub fn is_lambda_without_type_annotations(&self) -> bool { + if let ExpressionKind::Lambda(lambda) = self { + lambda.parameters.iter().any(|(_, typ)| typ.typ.is_unspecified()) + } else { + false + } } } diff --git a/compiler/noirc_frontend/src/ast/mod.rs b/compiler/noirc_frontend/src/ast/mod.rs index f8a82574bee..b4e3e06cc6f 100644 --- a/compiler/noirc_frontend/src/ast/mod.rs +++ b/compiler/noirc_frontend/src/ast/mod.rs @@ -448,6 +448,10 @@ impl UnresolvedTypeData { | UnresolvedTypeData::Error => false, } } + + pub fn is_unspecified(&self) -> bool { + matches!(self, UnresolvedTypeData::Unspecified) + } } #[derive(Debug, PartialEq, Eq, Copy, Clone, Hash, PartialOrd, Ord)] diff --git a/compiler/noirc_frontend/src/elaborator/expressions.rs b/compiler/noirc_frontend/src/elaborator/expressions.rs index ed7ed66119c..a3506a11689 100644 --- a/compiler/noirc_frontend/src/elaborator/expressions.rs +++ b/compiler/noirc_frontend/src/elaborator/expressions.rs @@ -388,7 +388,8 @@ impl<'context> Elaborator<'context> { fn elaborate_call(&mut self, call: CallExpression, span: Span) -> (HirExpression, Type) { let (func, func_type) = self.elaborate_expression(*call.func); - let any_argument_is_lambda = call.arguments.iter().any(|arg| arg.kind.is_lambda()); + let any_argument_has_lambda_without_type_annotations = + call.arguments.iter().any(|arg| arg.kind.is_lambda_without_type_annotations()); let mut arguments = Vec::with_capacity(call.arguments.len()); let args: Vec<_> = call @@ -400,13 +401,23 @@ impl<'context> Elaborator<'context> { let (arg, typ) = if call.is_macro_call { self.elaborate_in_comptime_context(|this| { - this.elaborate_call_argument_expression(arg, arg_index, &func_type) + this.elaborate_call_argument_expression( + arg, + arg_index, + &func_type, + any_argument_has_lambda_without_type_annotations, + ) }) } else { - self.elaborate_call_argument_expression(arg, arg_index, &func_type) + self.elaborate_call_argument_expression( + arg, + arg_index, + &func_type, + any_argument_has_lambda_without_type_annotations, + ) }; - if any_argument_is_lambda { + if any_argument_has_lambda_without_type_annotations { // Try to unify this argument type against the function's argument type // so that a potential lambda following this argument can have more concrete types. if let Type::Function(func_args, _, _, _) = &func_type { @@ -491,10 +502,12 @@ impl<'context> Elaborator<'context> { self.type_check_variable(function_name.clone(), function_id, generics.clone()); self.interner.push_expr_type(function_id, func_type.clone()); - let any_argument_is_lambda = - method_call.arguments.iter().any(|arg| arg.kind.is_lambda()); + let any_argument_has_lambda_without_type_annotations = method_call + .arguments + .iter() + .any(|arg| arg.kind.is_lambda_without_type_annotations()); - if any_argument_is_lambda { + if any_argument_has_lambda_without_type_annotations { // Try to unify the object type with the first argument of the function. // The reason to do this is that many methods that take a lambda will yield `self` or part of `self` // as a parameter. By unifying `self` with the first argument we'll potentially get more @@ -516,10 +529,14 @@ impl<'context> Elaborator<'context> { for (arg_index, arg) in method_call.arguments.into_iter().enumerate() { let span = arg.span; - let (arg, typ) = - self.elaborate_call_argument_expression(arg, arg_index + 1, &func_type); + let (arg, typ) = self.elaborate_call_argument_expression( + arg, + arg_index + 1, + &func_type, + any_argument_has_lambda_without_type_annotations, + ); - if any_argument_is_lambda { + if any_argument_has_lambda_without_type_annotations { // Try to unify this argument type against the function's argument type // so that a potential lambda following this argument can have more concrete types. if let Type::Function(func_args, _, _, _) = &func_type { @@ -576,7 +593,12 @@ impl<'context> Elaborator<'context> { arg: Expression, arg_index: usize, func_type: &Type, + any_argument_has_lambda_without_type_annotations: bool, ) -> (ExprId, Type) { + if !any_argument_has_lambda_without_type_annotations { + return self.elaborate_expression(arg); + } + let ExpressionKind::Lambda(lambda) = arg.kind else { return self.elaborate_expression(arg); }; From 1600673ac9e14c7d5dc01c1f2fd463f248c601a1 Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Thu, 16 Jan 2025 14:11:19 -0300 Subject: [PATCH 10/15] Apply suggestions from code review --- compiler/noirc_frontend/src/ast/expression.rs | 8 - .../src/elaborator/expressions.rs | 159 ++++++------------ 2 files changed, 56 insertions(+), 111 deletions(-) diff --git a/compiler/noirc_frontend/src/ast/expression.rs b/compiler/noirc_frontend/src/ast/expression.rs index ca6090a8776..9d521545e7a 100644 --- a/compiler/noirc_frontend/src/ast/expression.rs +++ b/compiler/noirc_frontend/src/ast/expression.rs @@ -223,14 +223,6 @@ impl ExpressionKind { struct_type: None, })) } - - pub fn is_lambda_without_type_annotations(&self) -> bool { - if let ExpressionKind::Lambda(lambda) = self { - lambda.parameters.iter().any(|(_, typ)| typ.typ.is_unspecified()) - } else { - false - } - } } impl Recoverable for ExpressionKind { diff --git a/compiler/noirc_frontend/src/elaborator/expressions.rs b/compiler/noirc_frontend/src/elaborator/expressions.rs index a3506a11689..e9b990d3e0e 100644 --- a/compiler/noirc_frontend/src/elaborator/expressions.rs +++ b/compiler/noirc_frontend/src/elaborator/expressions.rs @@ -387,50 +387,33 @@ impl<'context> Elaborator<'context> { fn elaborate_call(&mut self, call: CallExpression, span: Span) -> (HirExpression, Type) { let (func, func_type) = self.elaborate_expression(*call.func); - - let any_argument_has_lambda_without_type_annotations = - call.arguments.iter().any(|arg| arg.kind.is_lambda_without_type_annotations()); + let func_arg_types = + if let Type::Function(args, _, _, _) = &func_type { Some(args) } else { None }; let mut arguments = Vec::with_capacity(call.arguments.len()); - let args: Vec<_> = call - .arguments - .into_iter() - .enumerate() - .map(|(arg_index, arg)| { - let span = arg.span; - - let (arg, typ) = if call.is_macro_call { - self.elaborate_in_comptime_context(|this| { - this.elaborate_call_argument_expression( - arg, - arg_index, - &func_type, - any_argument_has_lambda_without_type_annotations, - ) - }) - } else { - self.elaborate_call_argument_expression( - arg, - arg_index, - &func_type, - any_argument_has_lambda_without_type_annotations, - ) - }; - - if any_argument_has_lambda_without_type_annotations { - // Try to unify this argument type against the function's argument type - // so that a potential lambda following this argument can have more concrete types. - if let Type::Function(func_args, _, _, _) = &func_type { - if let Some(func_arg_type) = func_args.get(arg_index) { - let _ = func_arg_type.unify(&typ); - } - } + let args = vecmap(call.arguments.into_iter().enumerate(), |(arg_index, arg)| { + let span = arg.span; + let expected_type = func_arg_types.and_then(|args| args.get(arg_index)); + + let (arg, typ) = if call.is_macro_call { + self.elaborate_in_comptime_context(|this| { + this.elaborate_expression_with_type(arg, expected_type) + }) + } else { + self.elaborate_expression_with_type(arg, expected_type) + }; + + // Try to unify this argument type against the function's argument type + // so that a potential lambda following this argument can have more concrete types. + if let Type::Function(func_args, _, _, _) = &func_type { + if let Some(func_arg_type) = func_args.get(arg_index) { + let _ = func_arg_type.unify(&typ); } + } - arguments.push(arg); - (typ, arg, span) - }) - .collect(); + arguments.push(arg); + (typ, arg, span) + }); // Avoid cloning arguments unless this is a macro call let mut comptime_args = Vec::new(); @@ -502,21 +485,17 @@ impl<'context> Elaborator<'context> { self.type_check_variable(function_name.clone(), function_id, generics.clone()); self.interner.push_expr_type(function_id, func_type.clone()); - let any_argument_has_lambda_without_type_annotations = method_call - .arguments - .iter() - .any(|arg| arg.kind.is_lambda_without_type_annotations()); - - if any_argument_has_lambda_without_type_annotations { - // Try to unify the object type with the first argument of the function. - // The reason to do this is that many methods that take a lambda will yield `self` or part of `self` - // as a parameter. By unifying `self` with the first argument we'll potentially get more - // concrete types in the arguments that are function types, which will later be passed as - // lambda parameter hints. - if let Type::Function(args, _, _, _) = &func_type { - if !args.is_empty() { - let _ = args[0].unify(&object_type); - } + let func_arg_types = + if let Type::Function(args, _, _, _) = &func_type { Some(args) } else { None }; + + // Try to unify the object type with the first argument of the function. + // The reason to do this is that many methods that take a lambda will yield `self` or part of `self` + // as a parameter. By unifying `self` with the first argument we'll potentially get more + // concrete types in the arguments that are function types, which will later be passed as + // lambda parameter hints. + if let Type::Function(args, _, _, _) = &func_type { + if !args.is_empty() { + let _ = args[0].unify(&object_type); } } @@ -529,20 +508,14 @@ impl<'context> Elaborator<'context> { for (arg_index, arg) in method_call.arguments.into_iter().enumerate() { let span = arg.span; - let (arg, typ) = self.elaborate_call_argument_expression( - arg, - arg_index + 1, - &func_type, - any_argument_has_lambda_without_type_annotations, - ); + let expected_type = func_arg_types.and_then(|args| args.get(arg_index + 1)); + let (arg, typ) = self.elaborate_expression_with_type(arg, expected_type); - if any_argument_has_lambda_without_type_annotations { - // Try to unify this argument type against the function's argument type - // so that a potential lambda following this argument can have more concrete types. - if let Type::Function(func_args, _, _, _) = &func_type { - if let Some(func_arg_type) = func_args.get(arg_index + 1) { - let _ = func_arg_type.unify(&typ); - } + // Try to unify this argument type against the function's argument type + // so that a potential lambda following this argument can have more concrete types. + if let Type::Function(func_args, _, _, _) = &func_type { + if let Some(func_arg_type) = func_args.get(arg_index + 1) { + let _ = func_arg_type.unify(&typ); } } @@ -586,33 +559,19 @@ impl<'context> Elaborator<'context> { } } - /// Elaborates an expression taking into account that it's a call argument in a function - /// that has the given type, and `arg_index` is the index of that argument in that function type. - fn elaborate_call_argument_expression( + /// Elaborates an expression knowing that it has to match a given type. + fn elaborate_expression_with_type( &mut self, arg: Expression, - arg_index: usize, - func_type: &Type, - any_argument_has_lambda_without_type_annotations: bool, + typ: Option<&Type>, ) -> (ExprId, Type) { - if !any_argument_has_lambda_without_type_annotations { - return self.elaborate_expression(arg); - } - let ExpressionKind::Lambda(lambda) = arg.kind else { return self.elaborate_expression(arg); }; let span = arg.span; - let type_hint = if let Type::Function(func_args, _, _, _) = func_type { - if let Some(Type::Function(func_args, _, _, _)) = func_args.get(arg_index) { - Some(func_args) - } else { - None - } - } else { - None - }; + let type_hint = + if let Some(Type::Function(func_args, _, _, _)) = typ { Some(func_args) } else { None }; let (hir_expr, typ) = self.elaborate_lambda(*lambda, type_hint); let id = self.interner.push_expr(hir_expr); self.interner.push_expr_location(id, span, self.file); @@ -960,30 +919,24 @@ impl<'context> Elaborator<'context> { self.lambda_stack.push(LambdaContext { captures: Vec::new(), scope_index }); let mut arg_types = Vec::with_capacity(lambda.parameters.len()); - let parameters: Vec<_> = lambda - .parameters - .into_iter() - .enumerate() - .map(|(index, (pattern, typ))| { + let parameters = + vecmap(lambda.parameters.into_iter().enumerate(), |(index, (pattern, typ))| { let parameter = DefinitionKind::Local(None); - let is_unspecified = matches!(typ.typ, UnresolvedTypeData::Unspecified); - let typ = self.resolve_inferred_type(typ); - - if is_unspecified { - // If there's a parameter type hint, use it to unify the argument type + let typ = if let UnresolvedTypeData::Unspecified = typ.typ { if let Some(parameter_type_hint) = parameters_type_hints.and_then(|hints| hints.get(index)) { - // We don't error here because eventually the lambda type will be checked against - // the call that contains it, which would then produce an error if this didn't unify. - let _ = typ.unify(parameter_type_hint); + parameter_type_hint.clone() + } else { + self.interner.next_type_variable_with_kind(Kind::Any) } - } + } else { + self.resolve_type(typ) + }; arg_types.push(typ.clone()); (self.elaborate_pattern(pattern, typ.clone(), parameter, true), typ) - }) - .collect(); + }); let return_type = self.resolve_inferred_type(lambda.return_type); let body_span = lambda.body.span; From 9ba7b7c3db4dd8a7943599562605d99f975bcf60 Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Thu, 16 Jan 2025 14:12:19 -0300 Subject: [PATCH 11/15] Remove unused method --- compiler/noirc_frontend/src/ast/mod.rs | 4 ---- 1 file changed, 4 deletions(-) diff --git a/compiler/noirc_frontend/src/ast/mod.rs b/compiler/noirc_frontend/src/ast/mod.rs index b4e3e06cc6f..f8a82574bee 100644 --- a/compiler/noirc_frontend/src/ast/mod.rs +++ b/compiler/noirc_frontend/src/ast/mod.rs @@ -448,10 +448,6 @@ impl UnresolvedTypeData { | UnresolvedTypeData::Error => false, } } - - pub fn is_unspecified(&self) -> bool { - matches!(self, UnresolvedTypeData::Unspecified) - } } #[derive(Debug, PartialEq, Eq, Copy, Clone, Hash, PartialOrd, Ord)] From 19d1b0c7a7f88669e14e9f29be6f4db7babce06d Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Thu, 16 Jan 2025 14:14:45 -0300 Subject: [PATCH 12/15] Use `expected_type` and `func_arg_types` --- .../src/elaborator/expressions.rs | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/compiler/noirc_frontend/src/elaborator/expressions.rs b/compiler/noirc_frontend/src/elaborator/expressions.rs index e9b990d3e0e..ef2ae9c4df0 100644 --- a/compiler/noirc_frontend/src/elaborator/expressions.rs +++ b/compiler/noirc_frontend/src/elaborator/expressions.rs @@ -405,10 +405,8 @@ impl<'context> Elaborator<'context> { // Try to unify this argument type against the function's argument type // so that a potential lambda following this argument can have more concrete types. - if let Type::Function(func_args, _, _, _) = &func_type { - if let Some(func_arg_type) = func_args.get(arg_index) { - let _ = func_arg_type.unify(&typ); - } + if let Some(expected_type) = expected_type { + let _ = expected_type.unify(&typ); } arguments.push(arg); @@ -493,10 +491,8 @@ impl<'context> Elaborator<'context> { // as a parameter. By unifying `self` with the first argument we'll potentially get more // concrete types in the arguments that are function types, which will later be passed as // lambda parameter hints. - if let Type::Function(args, _, _, _) = &func_type { - if !args.is_empty() { - let _ = args[0].unify(&object_type); - } + if let Some(first_arg_type) = func_arg_types.and_then(|args| args.first()) { + let _ = first_arg_type.unify(&object_type); } // These arguments will be given to the desugared function call. @@ -513,10 +509,8 @@ impl<'context> Elaborator<'context> { // Try to unify this argument type against the function's argument type // so that a potential lambda following this argument can have more concrete types. - if let Type::Function(func_args, _, _, _) = &func_type { - if let Some(func_arg_type) = func_args.get(arg_index + 1) { - let _ = func_arg_type.unify(&typ); - } + if let Some(expected_type) = expected_type { + let _ = expected_type.unify(&typ); } arguments.push(arg); From 60d4cc28a7336ce4027b7c0e8c3e044fe2cfc2a9 Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Fri, 17 Jan 2025 11:04:02 -0300 Subject: [PATCH 13/15] Keep track of whether an InfixExpr is an inversion --- compiler/noirc_frontend/src/elaborator/mod.rs | 4 +- .../noirc_frontend/src/elaborator/types.rs | 2 +- .../src/hir/comptime/hir_to_display_ast.rs | 2 +- compiler/noirc_frontend/src/hir_def/types.rs | 88 ++++++++++++++----- .../src/hir_def/types/arithmetic.rs | 50 ++++++----- .../src/monomorphization/mod.rs | 2 +- compiler/noirc_frontend/src/tests.rs | 22 +++++ tooling/lsp/src/requests/completion.rs | 2 +- tooling/lsp/src/requests/hover.rs | 2 +- .../src/trait_impl_method_stub_generator.rs | 2 +- 10 files changed, 123 insertions(+), 53 deletions(-) diff --git a/compiler/noirc_frontend/src/elaborator/mod.rs b/compiler/noirc_frontend/src/elaborator/mod.rs index d3dded22ab4..5299d9f5653 100644 --- a/compiler/noirc_frontend/src/elaborator/mod.rs +++ b/compiler/noirc_frontend/src/elaborator/mod.rs @@ -1043,7 +1043,7 @@ impl<'context> Elaborator<'context> { Type::MutableReference(typ) => { self.mark_type_as_used(typ); } - Type::InfixExpr(left, _op, right) => { + Type::InfixExpr(left, _op, right, _) => { self.mark_type_as_used(left); self.mark_type_as_used(right); } @@ -1688,7 +1688,7 @@ impl<'context> Elaborator<'context> { Type::MutableReference(typ) | Type::Array(_, typ) | Type::Slice(typ) => { self.check_type_is_not_more_private_then_item(name, visibility, typ, span); } - Type::InfixExpr(left, _op, right) => { + Type::InfixExpr(left, _op, right, _) => { self.check_type_is_not_more_private_then_item(name, visibility, left, span); self.check_type_is_not_more_private_then_item(name, visibility, right, span); } diff --git a/compiler/noirc_frontend/src/elaborator/types.rs b/compiler/noirc_frontend/src/elaborator/types.rs index 8fa0b210605..a1b63910a3e 100644 --- a/compiler/noirc_frontend/src/elaborator/types.rs +++ b/compiler/noirc_frontend/src/elaborator/types.rs @@ -535,7 +535,7 @@ impl<'context> Elaborator<'context> { } } (lhs, rhs) => { - let infix = Type::InfixExpr(Box::new(lhs), op, Box::new(rhs)); + let infix = Type::infix_expr(Box::new(lhs), op, Box::new(rhs)); Type::CheckedCast { from: Box::new(infix.clone()), to: Box::new(infix) } .canonicalize() } diff --git a/compiler/noirc_frontend/src/hir/comptime/hir_to_display_ast.rs b/compiler/noirc_frontend/src/hir/comptime/hir_to_display_ast.rs index 9338c0fc37f..86c5ab5fd18 100644 --- a/compiler/noirc_frontend/src/hir/comptime/hir_to_display_ast.rs +++ b/compiler/noirc_frontend/src/hir/comptime/hir_to_display_ast.rs @@ -359,7 +359,7 @@ impl Type { Type::Constant(..) => panic!("Type::Constant where a type was expected: {self:?}"), Type::Quoted(quoted_type) => UnresolvedTypeData::Quoted(*quoted_type), Type::Error => UnresolvedTypeData::Error, - Type::InfixExpr(lhs, op, rhs) => { + Type::InfixExpr(lhs, op, rhs, _) => { let lhs = Box::new(lhs.to_type_expression()); let rhs = Box::new(rhs.to_type_expression()); let span = Span::default(); diff --git a/compiler/noirc_frontend/src/hir_def/types.rs b/compiler/noirc_frontend/src/hir_def/types.rs index c0dbf6f9500..03dba47b329 100644 --- a/compiler/noirc_frontend/src/hir_def/types.rs +++ b/compiler/noirc_frontend/src/hir_def/types.rs @@ -97,10 +97,7 @@ pub enum Type { /// A cast (to, from) that's checked at monomorphization. /// /// Simplifications on arithmetic generics are only allowed on the LHS. - CheckedCast { - from: Box, - to: Box, - }, + CheckedCast { from: Box, to: Box }, /// A functions with arguments, a return type and environment. /// the environment should be `Unit` by default, @@ -132,7 +129,13 @@ pub enum Type { /// The type of quoted code in macros. This is always a comptime-only type Quoted(QuotedType), - InfixExpr(Box, BinaryTypeOperator, Box), + /// An infix expression in the form `lhs * rhs`. + /// + /// The `inversion` bool keeps track of whether this expression came from + /// an expression like `4 = a / b` which was transformed to `a = 4 / b` + /// so that if at some point a infix expression `b * (4 / b)` is created, + /// it could be simplified back to `4`. + InfixExpr(Box, BinaryTypeOperator, Box, bool /* inversion */), /// The result of some type error. Remembering type errors as their own type variant lets /// us avoid issuing repeat type errors for the same item. For example, a lambda with @@ -905,7 +908,7 @@ impl std::fmt::Display for Type { write!(f, "&mut {element}") } Type::Quoted(quoted) => write!(f, "{}", quoted), - Type::InfixExpr(lhs, op, rhs) => { + Type::InfixExpr(lhs, op, rhs, _) => { let this = self.canonicalize_checked(); // Prevent infinite recursion @@ -1146,7 +1149,7 @@ impl Type { .into_iter() .all(|(_, field)| field.is_valid_for_program_input()), - Type::InfixExpr(lhs, _, rhs) => { + Type::InfixExpr(lhs, _, rhs, _) => { lhs.is_valid_for_program_input() && rhs.is_valid_for_program_input() } } @@ -1307,7 +1310,7 @@ impl Type { TypeBinding::Bound(ref typ) => typ.kind(), TypeBinding::Unbound(_, ref type_var_kind) => type_var_kind.clone(), }, - Type::InfixExpr(lhs, _op, rhs) => lhs.infix_kind(rhs), + Type::InfixExpr(lhs, _op, rhs, _) => lhs.infix_kind(rhs), Type::Alias(def, generics) => def.borrow().get_type(generics).kind(), // This is a concrete FieldElement, not an IntegerOrField Type::FieldElement @@ -1340,6 +1343,48 @@ impl Type { } } + /// Creates an `InfixExpr`. + pub fn infix_expr(lhs: Box, op: BinaryTypeOperator, rhs: Box) -> Type { + Self::new_infix_expr(lhs, op, rhs, false) + } + + /// Creates an `InfixExpr` that results from the compiler trying to unify something like + /// `4 = a * b` into `a = 4 / b` (where `4 / b` is the "inverted" expression). + pub fn inverted_infix_expr(lhs: Box, op: BinaryTypeOperator, rhs: Box) -> Type { + Self::new_infix_expr(lhs, op, rhs, true) + } + + pub fn new_infix_expr( + lhs: Box, + op: BinaryTypeOperator, + rhs: Box, + inversion: bool, + ) -> Type { + // If an InfixExpr like this is tried to be created: + // + // a * (b / a) + // + // where `b / a` resulted from the compiler creating an inverted InfixExpr from a previous + // unification (that is, the compiler had `b = a / y` and ended up doing `y = b / a` where + // `y` is `rhs` here) then we can simplify this to just `b` because there wasn't an actual + // division in the original expression, so multiplying it back is just going back to the + // original `y`.gg + if let Type::InfixExpr(rhs_lhs, rhs_op, rhs_rhs, true) = &*rhs { + if op.approx_inverse() == Some(*rhs_op) && lhs == *rhs_rhs { + return *rhs_lhs.clone(); + } + } + + // Same thing but on the other side. + if let Type::InfixExpr(lhs_lhs, lhs_op, lhs_rhs, true) = &*lhs { + if op.approx_inverse() == Some(*lhs_op) && rhs == *lhs_rhs { + return *lhs_lhs.clone(); + } + } + + Self::InfixExpr(lhs, op, rhs, inversion) + } + /// Returns the number of field elements required to represent the type once encoded. pub fn field_count(&self, location: &Location) -> u32 { match self { @@ -1697,7 +1742,7 @@ impl Type { elem_a.try_unify(elem_b, bindings) } - (InfixExpr(lhs_a, op_a, rhs_a), InfixExpr(lhs_b, op_b, rhs_b)) => { + (InfixExpr(lhs_a, op_a, rhs_a, _), InfixExpr(lhs_b, op_b, rhs_b, _)) => { if op_a == op_b { // We need to preserve the original bindings since if syntactic equality // fails we fall back to other equality strategies. @@ -1724,14 +1769,15 @@ impl Type { } else { Err(UnificationError) } - } else if let InfixExpr(lhs, op, rhs) = other { + } else if let InfixExpr(lhs, op, rhs, _) = other { if let Some(inverse) = op.approx_inverse() { // Handle cases like `4 = a + b` by trying to solve to `a = 4 - b` - let new_type = InfixExpr( + let new_type = Type::inverted_infix_expr( Box::new(Constant(*value, kind.clone())), inverse, rhs.clone(), ); + new_type.try_unify(lhs, bindings)?; Ok(()) } else { @@ -1937,7 +1983,7 @@ impl Type { }) } } - Type::InfixExpr(lhs, op, rhs) => { + Type::InfixExpr(lhs, op, rhs, _) => { let infix_kind = lhs.infix_kind(&rhs); if kind.unifies(&infix_kind) { let lhs_value = lhs.evaluate_to_field_element_helper( @@ -2267,10 +2313,10 @@ impl Type { }); Type::TraitAsType(*s, name.clone(), TraitGenerics { ordered, named }) } - Type::InfixExpr(lhs, op, rhs) => { + Type::InfixExpr(lhs, op, rhs, inversion) => { let lhs = lhs.substitute_helper(type_bindings, substitute_bound_typevars); let rhs = rhs.substitute_helper(type_bindings, substitute_bound_typevars); - Type::InfixExpr(Box::new(lhs), *op, Box::new(rhs)) + Type::InfixExpr(Box::new(lhs), *op, Box::new(rhs), *inversion) } Type::FieldElement @@ -2320,7 +2366,7 @@ impl Type { || env.occurs(target_id) } Type::MutableReference(element) => element.occurs(target_id), - Type::InfixExpr(lhs, _op, rhs) => lhs.occurs(target_id) || rhs.occurs(target_id), + Type::InfixExpr(lhs, _op, rhs, _) => lhs.occurs(target_id) || rhs.occurs(target_id), Type::FieldElement | Type::Integer(_, _) @@ -2389,10 +2435,10 @@ impl Type { }); TraitAsType(*s, name.clone(), TraitGenerics { ordered, named }) } - InfixExpr(lhs, op, rhs) => { + InfixExpr(lhs, op, rhs, inversion) => { let lhs = lhs.follow_bindings(); let rhs = rhs.follow_bindings(); - InfixExpr(Box::new(lhs), *op, Box::new(rhs)) + InfixExpr(Box::new(lhs), *op, Box::new(rhs), *inversion) } // Expect that this function should only be called on instantiated types @@ -2502,7 +2548,7 @@ impl Type { } Type::MutableReference(elem) => elem.replace_named_generics_with_type_variables(), Type::Forall(_, typ) => typ.replace_named_generics_with_type_variables(), - Type::InfixExpr(lhs, _op, rhs) => { + Type::InfixExpr(lhs, _op, rhs, _) => { lhs.replace_named_generics_with_type_variables(); rhs.replace_named_generics_with_type_variables(); } @@ -2544,7 +2590,7 @@ impl Type { TypeBinding::Unbound(_, kind) => kind.integral_maximum_size(), }, Type::MutableReference(typ) => typ.integral_maximum_size(), - Type::InfixExpr(lhs, _op, rhs) => lhs.infix_kind(rhs).integral_maximum_size(), + Type::InfixExpr(lhs, _op, rhs, _) => lhs.infix_kind(rhs).integral_maximum_size(), Type::Constant(_, kind) => kind.integral_maximum_size(), Type::Array(..) @@ -2913,7 +2959,7 @@ impl std::hash::Hash for Type { Type::CheckedCast { to, .. } => to.hash(state), Type::Constant(value, _) => value.hash(state), Type::Quoted(typ) => typ.hash(state), - Type::InfixExpr(lhs, op, rhs) => { + Type::InfixExpr(lhs, op, rhs, _) => { lhs.hash(state); op.hash(state); rhs.hash(state); @@ -2982,7 +3028,7 @@ impl PartialEq for Type { lhs == rhs && lhs_kind == rhs_kind } (Quoted(lhs), Quoted(rhs)) => lhs == rhs, - (InfixExpr(l_lhs, l_op, l_rhs), InfixExpr(r_lhs, r_op, r_rhs)) => { + (InfixExpr(l_lhs, l_op, l_rhs, _), InfixExpr(r_lhs, r_op, r_rhs, _)) => { l_lhs == r_lhs && l_op == r_op && l_rhs == r_rhs } // Special case: we consider unbound named generics and type variables to be equal to each diff --git a/compiler/noirc_frontend/src/hir_def/types/arithmetic.rs b/compiler/noirc_frontend/src/hir_def/types/arithmetic.rs index 8cdf6f5502c..5750365c62d 100644 --- a/compiler/noirc_frontend/src/hir_def/types/arithmetic.rs +++ b/compiler/noirc_frontend/src/hir_def/types/arithmetic.rs @@ -58,7 +58,7 @@ impl Type { run_simplifications: bool, ) -> Type { match self.follow_bindings() { - Type::InfixExpr(lhs, op, rhs) => { + Type::InfixExpr(lhs, op, rhs, inversion) => { let kind = lhs.infix_kind(&rhs); let dummy_span = Span::default(); // evaluate_to_field_element also calls canonicalize so if we just called @@ -76,7 +76,7 @@ impl Type { let rhs = rhs.canonicalize_helper(found_checked_cast, run_simplifications); if !run_simplifications { - return Type::InfixExpr(Box::new(lhs), op, Box::new(rhs)); + return Type::InfixExpr(Box::new(lhs), op, Box::new(rhs), inversion); } if let Some(result) = Self::try_simplify_non_constants_in_lhs(&lhs, op, &rhs) { @@ -97,7 +97,7 @@ impl Type { return Self::sort_commutative(&lhs, op, &rhs); } - Type::InfixExpr(Box::new(lhs), op, Box::new(rhs)) + Type::InfixExpr(Box::new(lhs), op, Box::new(rhs), inversion) } Type::CheckedCast { from, to } => { let inner_found_checked_cast = true; @@ -131,7 +131,7 @@ impl Type { // Push each non-constant term to `sorted` to sort them. Recur on InfixExprs with the same operator. while let Some(item) = queue.pop() { match item.canonicalize_unchecked() { - Type::InfixExpr(lhs_inner, new_op, rhs_inner) if new_op == op => { + Type::InfixExpr(lhs_inner, new_op, rhs_inner, _) if new_op == op => { queue.push(*lhs_inner); queue.push(*rhs_inner); } @@ -157,18 +157,18 @@ impl Type { // - 1 since `typ` already is set to the first instance for _ in 0..first_type_count - 1 { - typ = Type::InfixExpr(Box::new(typ), op, Box::new(first.0.clone())); + typ = Type::infix_expr(Box::new(typ), op, Box::new(first.0.clone())); } for (rhs, rhs_count) in sorted { for _ in 0..rhs_count { - typ = Type::InfixExpr(Box::new(typ), op, Box::new(rhs.clone())); + typ = Type::infix_expr(Box::new(typ), op, Box::new(rhs.clone())); } } if constant != zero_value { let constant = Type::Constant(constant, lhs.infix_kind(rhs)); - typ = Type::InfixExpr(Box::new(typ), op, Box::new(constant)); + typ = Type::infix_expr(Box::new(typ), op, Box::new(constant)); } typ @@ -192,11 +192,11 @@ impl Type { match lhs.follow_bindings() { Type::CheckedCast { from, to } => { // Apply operation directly to `from` while attempting simplification to `to`. - let from = Type::InfixExpr(from, op, Box::new(rhs.clone())); + let from = Type::infix_expr(from, op, Box::new(rhs.clone())); let to = Self::try_simplify_non_constants_in_lhs(&to, op, rhs)?; Some(Type::CheckedCast { from: Box::new(from), to: Box::new(to) }) } - Type::InfixExpr(l_lhs, l_op, l_rhs) => { + Type::InfixExpr(l_lhs, l_op, l_rhs, _) => { // Note that this is exact, syntactic equality, not unification. // `rhs` is expected to already be in canonical form. if l_op.approx_inverse() != Some(op) @@ -229,11 +229,11 @@ impl Type { match rhs.follow_bindings() { Type::CheckedCast { from, to } => { // Apply operation directly to `from` while attempting simplification to `to`. - let from = Type::InfixExpr(Box::new(lhs.clone()), op, from); + let from = Type::infix_expr(Box::new(lhs.clone()), op, from); let to = Self::try_simplify_non_constants_in_rhs(lhs, op, &to)?; Some(Type::CheckedCast { from: Box::new(from), to: Box::new(to) }) } - Type::InfixExpr(r_lhs, r_op, r_rhs) => { + Type::InfixExpr(r_lhs, r_op, r_rhs, _) => { // `N / (M * N)` should be simplified to `1 / M`, but we only handle // simplifying to `M` in this function. if op == BinaryTypeOperator::Division && r_op == BinaryTypeOperator::Multiplication @@ -268,7 +268,7 @@ impl Type { let dummy_span = Span::default(); let rhs = rhs.evaluate_to_field_element(&kind, dummy_span).ok()?; - let Type::InfixExpr(l_type, l_op, l_rhs) = lhs.follow_bindings() else { + let Type::InfixExpr(l_type, l_op, l_rhs, _) = lhs.follow_bindings() else { return None; }; @@ -302,7 +302,7 @@ impl Type { let result = op.function(l_const, r_const, &lhs.infix_kind(rhs), dummy_span).ok()?; let constant = Type::Constant(result, lhs.infix_kind(rhs)); - Some(Type::InfixExpr(l_type, l_op, Box::new(constant))) + Some(Type::infix_expr(l_type, l_op, Box::new(constant))) } (Multiplication, Division) => { // We need to ensure the result divides evenly to preserve integer division semantics @@ -317,7 +317,7 @@ impl Type { let result = op.function(l_const, r_const, &lhs.infix_kind(rhs), dummy_span).ok()?; let constant = Box::new(Type::Constant(result, lhs.infix_kind(rhs))); - Some(Type::InfixExpr(l_type, l_op, constant)) + Some(Type::infix_expr(l_type, l_op, constant)) } } _ => None, @@ -331,13 +331,14 @@ impl Type { other: &Type, bindings: &mut TypeBindings, ) -> Result<(), UnificationError> { - if let Type::InfixExpr(lhs_a, op_a, rhs_a) = self { + if let Type::InfixExpr(lhs_a, op_a, rhs_a, _) = self { if let Some(inverse) = op_a.approx_inverse() { let kind = lhs_a.infix_kind(rhs_a); let dummy_span = Span::default(); if let Ok(rhs_a_value) = rhs_a.evaluate_to_field_element(&kind, dummy_span) { let rhs_a = Box::new(Type::Constant(rhs_a_value, kind)); - let new_other = Type::InfixExpr(Box::new(other.clone()), inverse, rhs_a); + let new_other = + Type::inverted_infix_expr(Box::new(other.clone()), inverse, rhs_a); let mut tmp_bindings = bindings.clone(); if lhs_a.try_unify(&new_other, &mut tmp_bindings).is_ok() { @@ -348,13 +349,14 @@ impl Type { } } - if let Type::InfixExpr(lhs_b, op_b, rhs_b) = other { + if let Type::InfixExpr(lhs_b, op_b, rhs_b, inversion) = other { if let Some(inverse) = op_b.approx_inverse() { let kind = lhs_b.infix_kind(rhs_b); let dummy_span = Span::default(); if let Ok(rhs_b_value) = rhs_b.evaluate_to_field_element(&kind, dummy_span) { let rhs_b = Box::new(Type::Constant(rhs_b_value, kind)); - let new_self = Type::InfixExpr(Box::new(self.clone()), inverse, rhs_b); + let new_self = + Type::InfixExpr(Box::new(self.clone()), inverse, rhs_b, !inversion); let mut tmp_bindings = bindings.clone(); if new_self.try_unify(lhs_b, &mut tmp_bindings).is_ok() { @@ -384,7 +386,7 @@ mod tests { TypeVariable::unbound(TypeVariableId(0), Kind::u32()), std::rc::Rc::new("N".to_owned()), ); - let n_minus_one = Type::InfixExpr( + let n_minus_one = Type::infix_expr( Box::new(n.clone()), BinaryTypeOperator::Subtraction, Box::new(Type::Constant(FieldElement::one(), Kind::u32())), @@ -392,7 +394,7 @@ mod tests { let checked_cast_n_minus_one = Type::CheckedCast { from: Box::new(n_minus_one.clone()), to: Box::new(n_minus_one) }; - let n_minus_one_plus_one = Type::InfixExpr( + let n_minus_one_plus_one = Type::infix_expr( Box::new(checked_cast_n_minus_one.clone()), BinaryTypeOperator::Addition, Box::new(Type::Constant(FieldElement::one(), Kind::u32())), @@ -405,7 +407,7 @@ mod tests { // We also want to check that if the `CheckedCast` is on the RHS then we'll still be able to canonicalize // the expression `1 + (N - 1)` to `N`. - let one_plus_n_minus_one = Type::InfixExpr( + let one_plus_n_minus_one = Type::infix_expr( Box::new(Type::Constant(FieldElement::one(), Kind::u32())), BinaryTypeOperator::Addition, Box::new(checked_cast_n_minus_one), @@ -423,13 +425,13 @@ mod tests { let x_type = Type::TypeVariable(x_var.clone()); let one = Type::Constant(FieldElement::one(), field_element_kind.clone()); - let lhs = Type::InfixExpr( + let lhs = Type::infix_expr( Box::new(x_type.clone()), BinaryTypeOperator::Addition, Box::new(one.clone()), ); let rhs = - Type::InfixExpr(Box::new(one), BinaryTypeOperator::Addition, Box::new(x_type.clone())); + Type::infix_expr(Box::new(one), BinaryTypeOperator::Addition, Box::new(x_type.clone())); // canonicalize let lhs = lhs.canonicalize(); @@ -546,7 +548,7 @@ mod proptests { 10, // We put up to 10 items per collection |inner| { (inner.clone(), any::(), inner) - .prop_map(|(lhs, op, rhs)| Type::InfixExpr(Box::new(lhs), op, Box::new(rhs))) + .prop_map(|(lhs, op, rhs)| Type::infix_expr(Box::new(lhs), op, Box::new(rhs))) }, ) } diff --git a/compiler/noirc_frontend/src/monomorphization/mod.rs b/compiler/noirc_frontend/src/monomorphization/mod.rs index b0c8744ea8f..2f6f366c5c9 100644 --- a/compiler/noirc_frontend/src/monomorphization/mod.rs +++ b/compiler/noirc_frontend/src/monomorphization/mod.rs @@ -1309,7 +1309,7 @@ impl<'interner> Monomorphizer<'interner> { } HirType::MutableReference(element) => Self::check_type(element, location), - HirType::InfixExpr(lhs, _, rhs) => { + HirType::InfixExpr(lhs, _, rhs, _) => { Self::check_type(lhs, location)?; Self::check_type(rhs, location) } diff --git a/compiler/noirc_frontend/src/tests.rs b/compiler/noirc_frontend/src/tests.rs index 5eaeb43ec8f..04c94843f53 100644 --- a/compiler/noirc_frontend/src/tests.rs +++ b/compiler/noirc_frontend/src/tests.rs @@ -4045,3 +4045,25 @@ fn infers_lambda_argument_from_call_function_type_in_generic_call() { "#; assert_no_errors(src); } + +#[test] +fn regression_7088() { + // A test for code that initially broke when implementing inferring + // lambda parameter types from the function type related to the call + // the lambda is in (PR #7088). + let src = r#" + struct U60Repr {} + + impl U60Repr { + fn new(_: [Field; N * NumFieldSegments]) -> Self { + U60Repr {} + } + } + + fn main() { + let input: [Field; 6] = [0; 6]; + let _: U60Repr<3, 6> = U60Repr::new(input); + } + "#; + assert_no_errors(src); +} diff --git a/tooling/lsp/src/requests/completion.rs b/tooling/lsp/src/requests/completion.rs index 0d737e29ff7..a845fd4496f 100644 --- a/tooling/lsp/src/requests/completion.rs +++ b/tooling/lsp/src/requests/completion.rs @@ -622,7 +622,7 @@ impl<'a> NodeFinder<'a> { | Type::Forall(_, _) | Type::Constant(..) | Type::Quoted(_) - | Type::InfixExpr(_, _, _) + | Type::InfixExpr(..) | Type::Error => (), } diff --git a/tooling/lsp/src/requests/hover.rs b/tooling/lsp/src/requests/hover.rs index ef1246d752d..5d8c50fa47b 100644 --- a/tooling/lsp/src/requests/hover.rs +++ b/tooling/lsp/src/requests/hover.rs @@ -721,7 +721,7 @@ impl<'a> TypeLinksGatherer<'a> { self.gather_type_links(env); } Type::MutableReference(typ) => self.gather_type_links(typ), - Type::InfixExpr(lhs, _, rhs) => { + Type::InfixExpr(lhs, _, rhs, _) => { self.gather_type_links(lhs); self.gather_type_links(rhs); } diff --git a/tooling/lsp/src/trait_impl_method_stub_generator.rs b/tooling/lsp/src/trait_impl_method_stub_generator.rs index 2ae0d526f3e..eb1709e34d0 100644 --- a/tooling/lsp/src/trait_impl_method_stub_generator.rs +++ b/tooling/lsp/src/trait_impl_method_stub_generator.rs @@ -361,7 +361,7 @@ impl<'a> TraitImplMethodStubGenerator<'a> { Type::Forall(_, _) => { panic!("Shouldn't get a Type::Forall"); } - Type::InfixExpr(left, op, right) => { + Type::InfixExpr(left, op, right, _) => { self.append_type(left); self.string.push(' '); self.string.push_str(&op.to_string()); From 8129a93bf098cee9c797ed70b99eb89b09aa620a Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Fri, 17 Jan 2025 11:04:46 -0300 Subject: [PATCH 14/15] Somehow missed one case --- compiler/noirc_frontend/src/hir_def/types.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler/noirc_frontend/src/hir_def/types.rs b/compiler/noirc_frontend/src/hir_def/types.rs index 03dba47b329..90e53c5c7a0 100644 --- a/compiler/noirc_frontend/src/hir_def/types.rs +++ b/compiler/noirc_frontend/src/hir_def/types.rs @@ -2873,7 +2873,7 @@ impl std::fmt::Debug for Type { write!(f, "&mut {element:?}") } Type::Quoted(quoted) => write!(f, "{}", quoted), - Type::InfixExpr(lhs, op, rhs) => write!(f, "({lhs:?} {op} {rhs:?})"), + Type::InfixExpr(lhs, op, rhs, _) => write!(f, "({lhs:?} {op} {rhs:?})"), } } } From bc844565585e1fe9ab14424e956b054ffd079e14 Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Fri, 17 Jan 2025 12:28:43 -0300 Subject: [PATCH 15/15] Update compiler/noirc_frontend/src/hir_def/types.rs Co-authored-by: jfecher --- compiler/noirc_frontend/src/hir_def/types.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler/noirc_frontend/src/hir_def/types.rs b/compiler/noirc_frontend/src/hir_def/types.rs index 90e53c5c7a0..513240a7495 100644 --- a/compiler/noirc_frontend/src/hir_def/types.rs +++ b/compiler/noirc_frontend/src/hir_def/types.rs @@ -1368,7 +1368,7 @@ impl Type { // unification (that is, the compiler had `b = a / y` and ended up doing `y = b / a` where // `y` is `rhs` here) then we can simplify this to just `b` because there wasn't an actual // division in the original expression, so multiplying it back is just going back to the - // original `y`.gg + // original `y` if let Type::InfixExpr(rhs_lhs, rhs_op, rhs_rhs, true) = &*rhs { if op.approx_inverse() == Some(*rhs_op) && lhs == *rhs_rhs { return *rhs_lhs.clone();