diff --git a/compiler/noirc_frontend/src/elaborator/comptime.rs b/compiler/noirc_frontend/src/elaborator/comptime.rs new file mode 100644 index 00000000000..0cbd2db55da --- /dev/null +++ b/compiler/noirc_frontend/src/elaborator/comptime.rs @@ -0,0 +1,71 @@ +use std::mem::replace; + +use crate::{ + hir_def::expr::HirIdent, + macros_api::Expression, + node_interner::{DependencyId, ExprId, FuncId}, +}; + +use super::{Elaborator, FunctionContext, ResolverMeta}; + +impl<'context> Elaborator<'context> { + /// Elaborate an expression from the middle of a comptime scope. + /// When this happens we require additional information to know + /// what variables should be in scope. + pub fn elaborate_expression_from_comptime( + &mut self, + expr: Expression, + function: Option, + ) -> ExprId { + self.function_context.push(FunctionContext::default()); + let old_scope = self.scopes.end_function(); + self.scopes.start_function(); + let function_id = function.map(DependencyId::Function); + let old_item = replace(&mut self.current_item, function_id); + + // Note: recover_generics isn't good enough here because any existing generics + // should not be in scope of this new function + let old_generics = std::mem::take(&mut self.generics); + + let old_crate_and_module = function.map(|function| { + let meta = self.interner.function_meta(&function); + let old_crate = replace(&mut self.crate_id, meta.source_crate); + let old_module = replace(&mut self.local_module, meta.source_module); + self.introduce_generics_into_scope(meta.all_generics.clone()); + (old_crate, old_module) + }); + + self.populate_scope_from_comptime_scopes(); + let expr = self.elaborate_expression(expr).0; + + if let Some((old_crate, old_module)) = old_crate_and_module { + self.crate_id = old_crate; + self.local_module = old_module; + } + + self.generics = old_generics; + self.current_item = old_item; + self.scopes.end_function(); + self.scopes.0.push(old_scope); + self.check_and_pop_function_context(); + expr + } + + fn populate_scope_from_comptime_scopes(&mut self) { + // Take the comptime scope to be our runtime scope. + // Iterate from global scope to the most local scope so that the + // later definitions will naturally shadow the former. + for scope in &self.comptime_scopes { + for definition_id in scope.keys() { + let definition = self.interner.definition(*definition_id); + let name = definition.name.clone(); + let location = definition.location; + + let scope = self.scopes.get_mut_scope(); + let ident = HirIdent::non_trait_method(*definition_id, location); + let meta = ResolverMeta { ident, num_times_used: 0, warn_if_unused: false }; + scope.add_key_value(name.clone(), meta); + } + } + } +} diff --git a/compiler/noirc_frontend/src/elaborator/expressions.rs b/compiler/noirc_frontend/src/elaborator/expressions.rs index 987c8c3f7ee..853098ce931 100644 --- a/compiler/noirc_frontend/src/elaborator/expressions.rs +++ b/compiler/noirc_frontend/src/elaborator/expressions.rs @@ -300,15 +300,21 @@ impl<'context> Elaborator<'context> { } let location = Location::new(span, self.file); - let hir_call = HirCallExpression { func, arguments, location }; - let typ = self.type_check_call(&hir_call, func_type, args, span); + let is_macro_call = call.is_macro_call; + let hir_call = HirCallExpression { func, arguments, location, is_macro_call }; + let mut typ = self.type_check_call(&hir_call, func_type, args, span); - if call.is_macro_call { - self.call_macro(func, comptime_args, location, typ) - .unwrap_or_else(|| (HirExpression::Error, Type::Error)) - } else { - (HirExpression::Call(hir_call), typ) + if is_macro_call { + if self.in_comptime_context() { + typ = self.interner.next_type_variable(); + } else { + return self + .call_macro(func, comptime_args, location, typ) + .unwrap_or_else(|| (HirExpression::Error, Type::Error)); + } } + + (HirExpression::Call(hir_call), typ) } fn elaborate_method_call( @@ -368,6 +374,7 @@ impl<'context> Elaborator<'context> { let location = Location::new(span, self.file); let method = method_call.method_name; let turbofish_generics = generics.clone(); + let is_macro_call = method_call.is_macro_call; let method_call = HirMethodCallExpression { method, object, arguments, location, generics }; @@ -377,6 +384,7 @@ impl<'context> Elaborator<'context> { let ((function_id, function_name), function_call) = method_call.into_function_call( &method_ref, object_type, + is_macro_call, location, self.interner, ); @@ -721,7 +729,7 @@ impl<'context> Elaborator<'context> { (id, typ) } - pub(super) fn inline_comptime_value( + pub fn inline_comptime_value( &mut self, value: Result, span: Span, @@ -801,14 +809,14 @@ impl<'context> Elaborator<'context> { for argument in arguments { match interpreter.evaluate(argument) { Ok(arg) => { - let location = interpreter.interner.expr_location(&argument); + let location = interpreter.elaborator.interner.expr_location(&argument); comptime_args.push((arg, location)); } Err(error) => errors.push((error.into(), file)), } } - let bindings = interpreter.interner.get_instantiation_bindings(func).clone(); + let bindings = interpreter.elaborator.interner.get_instantiation_bindings(func).clone(); let result = interpreter.call_function(function, comptime_args, bindings, location); if !errors.is_empty() { diff --git a/compiler/noirc_frontend/src/elaborator/mod.rs b/compiler/noirc_frontend/src/elaborator/mod.rs index 5327e911c47..ccbe67a49b8 100644 --- a/compiler/noirc_frontend/src/elaborator/mod.rs +++ b/compiler/noirc_frontend/src/elaborator/mod.rs @@ -7,7 +7,7 @@ use std::{ use crate::{ ast::{FunctionKind, UnresolvedTraitConstraint}, hir::{ - comptime::{self, Interpreter, InterpreterError, Value}, + comptime::{Interpreter, InterpreterError, Value}, def_collector::{ dc_crate::{ filter_literal_globals, CompilationError, ImplMap, UnresolvedGlobal, @@ -60,6 +60,7 @@ use crate::{ macros_api::ItemVisibility, }; +mod comptime; mod expressions; mod lints; mod patterns; @@ -97,9 +98,9 @@ pub struct LambdaContext { pub struct Elaborator<'context> { scopes: ScopeForest, - errors: Vec<(CompilationError, FileId)>, + pub(crate) errors: Vec<(CompilationError, FileId)>, - interner: &'context mut NodeInterner, + pub(crate) interner: &'context mut NodeInterner, def_maps: &'context mut BTreeMap, @@ -167,7 +168,7 @@ pub struct Elaborator<'context> { /// Each value currently in scope in the comptime interpreter. /// Each element of the Vec represents a scope with every scope together making /// up all currently visible definitions. The first scope is always the global scope. - comptime_scopes: Vec>, + pub(crate) comptime_scopes: Vec>, /// The scope of --debug-comptime, or None if unset debug_comptime_in_file: Option, @@ -228,6 +229,15 @@ impl<'context> Elaborator<'context> { items: CollectedItems, debug_comptime_in_file: Option, ) -> Vec<(CompilationError, FileId)> { + Self::elaborate_and_return_self(context, crate_id, items, debug_comptime_in_file).errors + } + + pub fn elaborate_and_return_self( + context: &'context mut Context, + crate_id: CrateId, + items: CollectedItems, + debug_comptime_in_file: Option, + ) -> Self { let mut this = Self::new(context, crate_id, debug_comptime_in_file); // Filter out comptime items to execute their functions first if needed. @@ -238,7 +248,7 @@ impl<'context> Elaborator<'context> { let (comptime_items, runtime_items) = Self::filter_comptime_items(items); this.elaborate_items(comptime_items); this.elaborate_items(runtime_items); - this.errors + this } fn elaborate_items(&mut self, mut items: CollectedItems) { @@ -339,6 +349,21 @@ impl<'context> Elaborator<'context> { self.trait_id = None; } + fn introduce_generics_into_scope(&mut self, all_generics: Vec) { + // Introduce all numeric generics into scope + for generic in &all_generics { + if let Kind::Numeric(typ) = &generic.kind { + let definition = DefinitionKind::GenericType(generic.type_var.clone()); + let ident = Ident::new(generic.name.to_string(), generic.span); + let hir_ident = + self.add_variable_decl_inner(ident, false, false, false, definition); + self.interner.push_definition_type(hir_ident.id, *typ.clone()); + } + } + + self.generics = all_generics; + } + fn elaborate_function(&mut self, id: FuncId) { let func_meta = self.interner.func_meta.get_mut(&id); let func_meta = @@ -360,16 +385,7 @@ impl<'context> Elaborator<'context> { self.trait_bounds = func_meta.trait_constraints.clone(); self.function_context.push(FunctionContext::default()); - // Introduce all numeric generics into scope - for generic in &func_meta.all_generics { - if let Kind::Numeric(typ) = &generic.kind { - let definition = DefinitionKind::GenericType(generic.type_var.clone()); - let ident = Ident::new(generic.name.to_string(), generic.span); - let hir_ident = - self.add_variable_decl_inner(ident, false, false, false, definition); - self.interner.push_definition_type(hir_ident.id, *typ.clone()); - } - } + self.introduce_generics_into_scope(func_meta.all_generics.clone()); // The DefinitionIds for each parameter were already created in define_function_meta // so we need to reintroduce the same IDs into scope here. @@ -378,8 +394,6 @@ impl<'context> Elaborator<'context> { self.add_existing_variable_to_scope(name, parameter.clone(), true); } - self.generics = func_meta.all_generics.clone(); - self.declare_numeric_generics(&func_meta.parameters, func_meta.return_type()); self.add_trait_constraints_to_scope(&func_meta); @@ -758,6 +772,7 @@ impl<'context> Elaborator<'context> { is_trait_function, has_inline_attribute, source_crate: self.crate_id, + source_module: self.local_module, function_body: FunctionBody::Unresolved(func.kind, body, func.def.span), }; @@ -1626,8 +1641,12 @@ impl<'context> Elaborator<'context> { } } - fn setup_interpreter(&mut self) -> Interpreter { - Interpreter::new(self.interner, &mut self.comptime_scopes, self.crate_id) + pub fn setup_interpreter<'local>(&'local mut self) -> Interpreter<'local, 'context> { + let current_function = match self.current_item { + Some(DependencyId::Function(function)) => Some(function), + _ => None, + }; + Interpreter::new(self, self.crate_id, current_function) } fn debug_comptime T>( diff --git a/compiler/noirc_frontend/src/elaborator/patterns.rs b/compiler/noirc_frontend/src/elaborator/patterns.rs index 99cdc86dc96..e24b6a3a067 100644 --- a/compiler/noirc_frontend/src/elaborator/patterns.rs +++ b/compiler/noirc_frontend/src/elaborator/patterns.rs @@ -5,7 +5,6 @@ use rustc_hash::FxHashSet as HashSet; use crate::{ ast::{UnresolvedType, ERROR_IDENT}, hir::{ - comptime::Interpreter, def_collector::dc_crate::CompilationError, resolution::errors::ResolverError, type_check::{Source, TypeCheckError}, @@ -460,8 +459,7 @@ impl<'context> Elaborator<'context> { // Comptime variables must be replaced with their values if let Some(definition) = self.interner.try_definition(definition_id) { if definition.comptime && !self.in_comptime_context() { - let mut interpreter = - Interpreter::new(self.interner, &mut self.comptime_scopes, self.crate_id); + let mut interpreter = self.setup_interpreter(); let value = interpreter.evaluate(id); return self.inline_comptime_value(value, span); } diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter.rs index a4440e34285..31ee8ef7200 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter.rs @@ -7,6 +7,7 @@ use noirc_errors::Location; use rustc_hash::FxHashMap as HashMap; use crate::ast::{BinaryOpKind, FunctionKind, IntegerBitSize, Signedness}; +use crate::elaborator::Elaborator; use crate::graph::CrateId; use crate::hir_def::expr::ImplKind; use crate::macros_api::UnaryOp; @@ -40,28 +41,25 @@ mod builtin; mod unquote; #[allow(unused)] -pub struct Interpreter<'interner> { - /// To expand macros the Interpreter may mutate hir nodes within the NodeInterner - pub interner: &'interner mut NodeInterner, - - /// Each value currently in scope in the interpreter. - /// Each element of the Vec represents a scope with every scope together making - /// up all currently visible definitions. - scopes: &'interner mut Vec>, +pub struct Interpreter<'local, 'interner> { + /// To expand macros the Interpreter needs access to the Elaborator + pub elaborator: &'local mut Elaborator<'interner>, crate_id: CrateId, in_loop: bool, + + current_function: Option, } #[allow(unused)] -impl<'a> Interpreter<'a> { +impl<'local, 'interner> Interpreter<'local, 'interner> { pub(crate) fn new( - interner: &'a mut NodeInterner, - scopes: &'a mut Vec>, + elaborator: &'local mut Elaborator<'interner>, crate_id: CrateId, + current_function: Option, ) -> Self { - Self { interner, scopes, crate_id, in_loop: false } + Self { elaborator, crate_id, current_function, in_loop: false } } pub(crate) fn call_function( @@ -71,11 +69,16 @@ impl<'a> Interpreter<'a> { instantiation_bindings: TypeBindings, location: Location, ) -> IResult { - let trait_method = self.interner.get_trait_method_id(function); + let trait_method = self.elaborator.interner.get_trait_method_id(function); perform_instantiation_bindings(&instantiation_bindings); - let impl_bindings = perform_impl_bindings(self.interner, trait_method, function, location)?; + let impl_bindings = + perform_impl_bindings(self.elaborator.interner, trait_method, function, location)?; + let old_function = self.current_function.replace(function); + let result = self.call_function_inner(function, arguments, location); + + self.current_function = old_function; undo_instantiation_bindings(impl_bindings); undo_instantiation_bindings(instantiation_bindings); result @@ -87,7 +90,7 @@ impl<'a> Interpreter<'a> { arguments: Vec<(Value, Location)>, location: Location, ) -> IResult { - let meta = self.interner.function_meta(&function); + let meta = self.elaborator.interner.function_meta(&function); if meta.parameters.len() != arguments.len() { return Err(InterpreterError::ArgumentCountMismatch { expected: meta.parameters.len(), @@ -96,11 +99,11 @@ impl<'a> Interpreter<'a> { }); } - let is_comptime = self.interner.function_modifiers(&function).is_comptime; + let is_comptime = self.elaborator.interner.function_modifiers(&function).is_comptime; if !is_comptime && meta.source_crate == self.crate_id { // Calling non-comptime functions from within the current crate is restricted // as non-comptime items will have not been elaborated yet. - let function = self.interner.function_name(&function).to_owned(); + let function = self.elaborator.interner.function_name(&function).to_owned(); return Err(InterpreterError::NonComptimeFnCallInSameCrate { function, location }); } @@ -115,10 +118,11 @@ impl<'a> Interpreter<'a> { self.define_pattern(parameter, typ, argument, arg_location)?; } - let function_body = self.interner.function(&function).try_as_expr().ok_or_else(|| { - let function = self.interner.function_name(&function).to_owned(); - InterpreterError::NonComptimeFnCallInSameCrate { function, location } - })?; + let function_body = + self.elaborator.interner.function(&function).try_as_expr().ok_or_else(|| { + let function = self.elaborator.interner.function_name(&function).to_owned(); + InterpreterError::NonComptimeFnCallInSameCrate { function, location } + })?; let result = self.evaluate(function_body)?; @@ -132,13 +136,13 @@ impl<'a> Interpreter<'a> { arguments: Vec<(Value, Location)>, location: Location, ) -> IResult { - let attributes = self.interner.function_attributes(&function); + let attributes = self.elaborator.interner.function_attributes(&function); let func_attrs = attributes.function.as_ref() .expect("all builtin functions must contain a function attribute which contains the opcode which it links to"); if let Some(builtin) = func_attrs.builtin() { let builtin = builtin.clone(); - builtin::call_builtin(self.interner, &builtin, arguments, location) + builtin::call_builtin(self.elaborator.interner, &builtin, arguments, location) } else if let Some(foreign) = func_attrs.foreign() { let item = format!("Comptime evaluation for foreign functions like {foreign}"); Err(InterpreterError::Unimplemented { item, location }) @@ -150,7 +154,7 @@ impl<'a> Interpreter<'a> { Err(InterpreterError::Unimplemented { item, location }) } } else { - let name = self.interner.function_name(&function); + let name = self.elaborator.interner.function_name(&function); unreachable!("Non-builtin, lowlevel or oracle builtin fn '{name}'") } } @@ -190,8 +194,8 @@ impl<'a> Interpreter<'a> { pub(super) fn enter_function(&mut self) -> (bool, Vec>) { // Drain every scope except the global scope let mut scope = Vec::new(); - if self.scopes.len() > 1 { - scope = self.scopes.drain(1..).collect(); + if self.elaborator.comptime_scopes.len() > 1 { + scope = self.elaborator.comptime_scopes.drain(1..).collect(); } self.push_scope(); (std::mem::take(&mut self.in_loop), scope) @@ -201,21 +205,21 @@ impl<'a> Interpreter<'a> { self.in_loop = state.0; // Keep only the global scope - self.scopes.truncate(1); - self.scopes.append(&mut state.1); + self.elaborator.comptime_scopes.truncate(1); + self.elaborator.comptime_scopes.append(&mut state.1); } pub(super) fn push_scope(&mut self) { - self.scopes.push(HashMap::default()); + self.elaborator.comptime_scopes.push(HashMap::default()); } pub(super) fn pop_scope(&mut self) { - self.scopes.pop(); + self.elaborator.comptime_scopes.pop(); } fn current_scope_mut(&mut self) -> &mut HashMap { // the global scope is always at index zero, so this is always Some - self.scopes.last_mut().unwrap() + self.elaborator.comptime_scopes.last_mut().unwrap() } pub(super) fn define_pattern( @@ -298,7 +302,7 @@ impl<'a> Interpreter<'a> { return Ok(()); } - for scope in self.scopes.iter_mut().rev() { + for scope in self.elaborator.comptime_scopes.iter_mut().rev() { if let Entry::Occupied(mut entry) = scope.entry(id) { entry.insert(argument); return Ok(()); @@ -312,7 +316,7 @@ impl<'a> Interpreter<'a> { } pub fn lookup_id(&self, id: DefinitionId, location: Location) -> IResult { - for scope in self.scopes.iter().rev() { + for scope in self.elaborator.comptime_scopes.iter().rev() { if let Some(value) = scope.get(&id) { return Ok(value.clone()); } @@ -321,7 +325,8 @@ impl<'a> Interpreter<'a> { if id == DefinitionId::dummy_id() { Err(InterpreterError::VariableNotInScope { location }) } else { - let name = self.interner.definition_name(id).to_string(); + let name = self.elaborator.interner.definition_name(id).to_string(); + eprintln!("{name} not in scope"); Err(InterpreterError::NonComptimeVarReferenced { name, location }) } } @@ -339,7 +344,7 @@ impl<'a> Interpreter<'a> { /// This function should be used when that is not desired - e.g. when /// compiling a `&mut var` expression to grab the original reference. fn evaluate_no_dereference(&mut self, id: ExprId) -> IResult { - match self.interner.expression(&id) { + match self.elaborator.interner.expression(&id) { HirExpression::Ident(ident, _) => self.evaluate_ident(ident, id), HirExpression::Literal(literal) => self.evaluate_literal(literal, id), HirExpression::Block(block) => self.evaluate_block(block), @@ -359,33 +364,34 @@ impl<'a> Interpreter<'a> { HirExpression::Unquote(tokens) => { // An Unquote expression being found is indicative of a macro being // expanded within another comptime fn which we don't currently support. - let location = self.interner.expr_location(&id); + let location = self.elaborator.interner.expr_location(&id); Err(InterpreterError::UnquoteFoundDuringEvaluation { location }) } HirExpression::Error => { - let location = self.interner.expr_location(&id); + let location = self.elaborator.interner.expr_location(&id); Err(InterpreterError::ErrorNodeEncountered { location }) } } } pub(super) fn evaluate_ident(&mut self, ident: HirIdent, id: ExprId) -> IResult { - let definition = self.interner.try_definition(ident.id).ok_or_else(|| { - let location = self.interner.expr_location(&id); + let definition = self.elaborator.interner.try_definition(ident.id).ok_or_else(|| { + let location = self.elaborator.interner.expr_location(&id); InterpreterError::VariableNotInScope { location } })?; if let ImplKind::TraitMethod(method, _, _) = ident.impl_kind { - let method_id = resolve_trait_method(self.interner, method, id)?; - let typ = self.interner.id_type(id).follow_bindings(); - let bindings = self.interner.get_instantiation_bindings(id).clone(); + let method_id = resolve_trait_method(self.elaborator.interner, method, id)?; + let typ = self.elaborator.interner.id_type(id).follow_bindings(); + let bindings = self.elaborator.interner.get_instantiation_bindings(id).clone(); return Ok(Value::Function(method_id, typ, Rc::new(bindings))); } match &definition.kind { DefinitionKind::Function(function_id) => { - let typ = self.interner.id_type(id).follow_bindings(); - let bindings = Rc::new(self.interner.get_instantiation_bindings(id).clone()); + let typ = self.elaborator.interner.id_type(id).follow_bindings(); + let bindings = + Rc::new(self.elaborator.interner.get_instantiation_bindings(id).clone()); Ok(Value::Function(*function_id, typ, bindings)) } DefinitionKind::Local(_) => self.lookup(&ident), @@ -395,10 +401,12 @@ impl<'a> Interpreter<'a> { Ok(value) } else { let let_ = - self.interner.get_global_let_statement(*global_id).ok_or_else(|| { - let location = self.interner.expr_location(&id); - InterpreterError::VariableNotInScope { location } - })?; + self.elaborator.interner.get_global_let_statement(*global_id).ok_or_else( + || { + let location = self.elaborator.interner.expr_location(&id); + InterpreterError::VariableNotInScope { location } + }, + )?; if let_.comptime { self.evaluate_let(let_.clone())?; @@ -413,10 +421,10 @@ impl<'a> Interpreter<'a> { }; if let Some(value) = value { - let typ = self.interner.id_type(id); + let typ = self.elaborator.interner.id_type(id); self.evaluate_integer((value as u128).into(), false, id) } else { - let location = self.interner.expr_location(&id); + let location = self.elaborator.interner.expr_location(&id); let typ = Type::TypeVariable(type_variable.clone(), TypeVariableKind::Normal); Err(InterpreterError::NonIntegerArrayLength { typ, location }) } @@ -434,7 +442,7 @@ impl<'a> Interpreter<'a> { HirLiteral::Str(string) => Ok(Value::String(Rc::new(string))), HirLiteral::FmtStr(_, _) => { let item = "format strings in a comptime context".into(); - let location = self.interner.expr_location(&id); + let location = self.elaborator.interner.expr_location(&id); Err(InterpreterError::Unimplemented { item, location }) } HirLiteral::Array(array) => self.evaluate_array(array, id), @@ -448,8 +456,8 @@ impl<'a> Interpreter<'a> { is_negative: bool, id: ExprId, ) -> IResult { - let typ = self.interner.id_type(id).follow_bindings(); - let location = self.interner.expr_location(&id); + let typ = self.elaborator.interner.id_type(id).follow_bindings(); + let location = self.elaborator.interner.expr_location(&id); if let Type::FieldElement = &typ { Ok(Value::Field(value)) @@ -562,7 +570,7 @@ impl<'a> Interpreter<'a> { } fn evaluate_array(&mut self, array: HirArrayLiteral, id: ExprId) -> IResult { - let typ = self.interner.id_type(id).follow_bindings(); + let typ = self.elaborator.interner.id_type(id).follow_bindings(); match array { HirArrayLiteral::Standard(elements) => { @@ -580,7 +588,7 @@ impl<'a> Interpreter<'a> { let elements = (0..length).map(|_| element.clone()).collect(); Ok(Value::Array(elements, typ)) } else { - let location = self.interner.expr_location(&id); + let location = self.elaborator.interner.expr_location(&id); Err(InterpreterError::NonIntegerArrayLength { typ: length, location }) } } @@ -600,7 +608,7 @@ impl<'a> Interpreter<'a> { _ => self.evaluate(prefix.rhs)?, }; - if self.interner.get_selected_impl_for_expression(id).is_some() { + if self.elaborator.interner.get_selected_impl_for_expression(id).is_some() { self.evaluate_overloaded_prefix(prefix, rhs, id) } else { self.evaluate_prefix_with_value(rhs, prefix.operator, id) @@ -625,7 +633,7 @@ impl<'a> Interpreter<'a> { Value::U32(value) => Ok(Value::U32(0 - value)), Value::U64(value) => Ok(Value::U64(0 - value)), value => { - let location = self.interner.expr_location(&id); + let location = self.elaborator.interner.expr_location(&id); let operator = "minus"; Err(InterpreterError::InvalidValueForUnary { value, location, operator }) } @@ -641,7 +649,7 @@ impl<'a> Interpreter<'a> { Value::U32(value) => Ok(Value::U32(!value)), Value::U64(value) => Ok(Value::U64(!value)), value => { - let location = self.interner.expr_location(&id); + let location = self.elaborator.interner.expr_location(&id); Err(InterpreterError::InvalidValueForUnary { value, location, operator: "not" }) } }, @@ -657,7 +665,7 @@ impl<'a> Interpreter<'a> { UnaryOp::Dereference { implicitly_added: _ } => match rhs { Value::Pointer(element, _) => Ok(element.borrow().clone()), value => { - let location = self.interner.expr_location(&id); + let location = self.elaborator.interner.expr_location(&id); Err(InterpreterError::NonPointerDereferenced { value, location }) } }, @@ -668,7 +676,7 @@ impl<'a> Interpreter<'a> { let lhs = self.evaluate(infix.lhs)?; let rhs = self.evaluate(infix.rhs)?; - if self.interner.get_selected_impl_for_expression(id).is_some() { + if self.elaborator.interner.get_selected_impl_for_expression(id).is_some() { return self.evaluate_overloaded_infix(infix, lhs, rhs, id); } @@ -685,7 +693,7 @@ impl<'a> Interpreter<'a> { (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::U32(lhs + rhs)), (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64(lhs + rhs)), (lhs, rhs) => { - let location = self.interner.expr_location(&id); + let location = self.elaborator.interner.expr_location(&id); Err(InvalidValuesForBinary { lhs, rhs, location, operator: "+" }) } }, @@ -700,7 +708,7 @@ impl<'a> Interpreter<'a> { (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::U32(lhs - rhs)), (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64(lhs - rhs)), (lhs, rhs) => { - let location = self.interner.expr_location(&id); + let location = self.elaborator.interner.expr_location(&id); Err(InvalidValuesForBinary { lhs, rhs, location, operator: "-" }) } }, @@ -715,7 +723,7 @@ impl<'a> Interpreter<'a> { (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::U32(lhs * rhs)), (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64(lhs * rhs)), (lhs, rhs) => { - let location = self.interner.expr_location(&id); + let location = self.elaborator.interner.expr_location(&id); Err(InvalidValuesForBinary { lhs, rhs, location, operator: "*" }) } }, @@ -730,7 +738,7 @@ impl<'a> Interpreter<'a> { (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::U32(lhs / rhs)), (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64(lhs / rhs)), (lhs, rhs) => { - let location = self.interner.expr_location(&id); + let location = self.elaborator.interner.expr_location(&id); Err(InvalidValuesForBinary { lhs, rhs, location, operator: "/" }) } }, @@ -745,7 +753,7 @@ impl<'a> Interpreter<'a> { (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::Bool(lhs == rhs)), (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::Bool(lhs == rhs)), (lhs, rhs) => { - let location = self.interner.expr_location(&id); + let location = self.elaborator.interner.expr_location(&id); Err(InvalidValuesForBinary { lhs, rhs, location, operator: "==" }) } }, @@ -760,7 +768,7 @@ impl<'a> Interpreter<'a> { (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::Bool(lhs != rhs)), (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::Bool(lhs != rhs)), (lhs, rhs) => { - let location = self.interner.expr_location(&id); + let location = self.elaborator.interner.expr_location(&id); Err(InvalidValuesForBinary { lhs, rhs, location, operator: "!=" }) } }, @@ -775,7 +783,7 @@ impl<'a> Interpreter<'a> { (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::Bool(lhs < rhs)), (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::Bool(lhs < rhs)), (lhs, rhs) => { - let location = self.interner.expr_location(&id); + let location = self.elaborator.interner.expr_location(&id); Err(InvalidValuesForBinary { lhs, rhs, location, operator: "<" }) } }, @@ -790,7 +798,7 @@ impl<'a> Interpreter<'a> { (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::Bool(lhs <= rhs)), (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::Bool(lhs <= rhs)), (lhs, rhs) => { - let location = self.interner.expr_location(&id); + let location = self.elaborator.interner.expr_location(&id); Err(InvalidValuesForBinary { lhs, rhs, location, operator: "<=" }) } }, @@ -805,7 +813,7 @@ impl<'a> Interpreter<'a> { (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::Bool(lhs > rhs)), (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::Bool(lhs > rhs)), (lhs, rhs) => { - let location = self.interner.expr_location(&id); + let location = self.elaborator.interner.expr_location(&id); Err(InvalidValuesForBinary { lhs, rhs, location, operator: ">" }) } }, @@ -820,7 +828,7 @@ impl<'a> Interpreter<'a> { (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::Bool(lhs >= rhs)), (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::Bool(lhs >= rhs)), (lhs, rhs) => { - let location = self.interner.expr_location(&id); + let location = self.elaborator.interner.expr_location(&id); Err(InvalidValuesForBinary { lhs, rhs, location, operator: ">=" }) } }, @@ -835,7 +843,7 @@ impl<'a> Interpreter<'a> { (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::U32(lhs & rhs)), (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64(lhs & rhs)), (lhs, rhs) => { - let location = self.interner.expr_location(&id); + let location = self.elaborator.interner.expr_location(&id); Err(InvalidValuesForBinary { lhs, rhs, location, operator: "&" }) } }, @@ -850,7 +858,7 @@ impl<'a> Interpreter<'a> { (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::U32(lhs | rhs)), (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64(lhs | rhs)), (lhs, rhs) => { - let location = self.interner.expr_location(&id); + let location = self.elaborator.interner.expr_location(&id); Err(InvalidValuesForBinary { lhs, rhs, location, operator: "|" }) } }, @@ -865,7 +873,7 @@ impl<'a> Interpreter<'a> { (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::U32(lhs ^ rhs)), (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64(lhs ^ rhs)), (lhs, rhs) => { - let location = self.interner.expr_location(&id); + let location = self.elaborator.interner.expr_location(&id); Err(InvalidValuesForBinary { lhs, rhs, location, operator: "^" }) } }, @@ -879,7 +887,7 @@ impl<'a> Interpreter<'a> { (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::U32(lhs >> rhs)), (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64(lhs >> rhs)), (lhs, rhs) => { - let location = self.interner.expr_location(&id); + let location = self.elaborator.interner.expr_location(&id); Err(InvalidValuesForBinary { lhs, rhs, location, operator: ">>" }) } }, @@ -893,7 +901,7 @@ impl<'a> Interpreter<'a> { (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::U32(lhs << rhs)), (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64(lhs << rhs)), (lhs, rhs) => { - let location = self.interner.expr_location(&id); + let location = self.elaborator.interner.expr_location(&id); Err(InvalidValuesForBinary { lhs, rhs, location, operator: "<<" }) } }, @@ -907,7 +915,7 @@ impl<'a> Interpreter<'a> { (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::U32(lhs % rhs)), (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64(lhs % rhs)), (lhs, rhs) => { - let location = self.interner.expr_location(&id); + let location = self.elaborator.interner.expr_location(&id); Err(InvalidValuesForBinary { lhs, rhs, location, operator: "%" }) } }, @@ -924,13 +932,13 @@ impl<'a> Interpreter<'a> { let method = infix.trait_method_id; let operator = infix.operator.kind; - let method_id = resolve_trait_method(self.interner, method, id)?; - let type_bindings = self.interner.get_instantiation_bindings(id).clone(); + let method_id = resolve_trait_method(self.elaborator.interner, method, id)?; + let type_bindings = self.elaborator.interner.get_instantiation_bindings(id).clone(); - let lhs = (lhs, self.interner.expr_location(&infix.lhs)); - let rhs = (rhs, self.interner.expr_location(&infix.rhs)); + let lhs = (lhs, self.elaborator.interner.expr_location(&infix.lhs)); + let rhs = (rhs, self.elaborator.interner.expr_location(&infix.rhs)); - let location = self.interner.expr_location(&id); + let location = self.elaborator.interner.expr_location(&id); let value = self.call_function(method_id, vec![lhs, rhs], type_bindings, location)?; // Certain operators add additional operations after the trait call: @@ -954,12 +962,12 @@ impl<'a> Interpreter<'a> { prefix.trait_method_id.expect("ice: expected prefix operator trait at this point"); let operator = prefix.operator; - let method_id = resolve_trait_method(self.interner, method, id)?; - let type_bindings = self.interner.get_instantiation_bindings(id).clone(); + let method_id = resolve_trait_method(self.elaborator.interner, method, id)?; + let type_bindings = self.elaborator.interner.get_instantiation_bindings(id).clone(); - let rhs = (rhs, self.interner.expr_location(&prefix.rhs)); + let rhs = (rhs, self.elaborator.interner.expr_location(&prefix.rhs)); - let location = self.interner.expr_location(&id); + let location = self.elaborator.interner.expr_location(&id); self.call_function(method_id, vec![rhs], type_bindings, location) } @@ -995,7 +1003,7 @@ impl<'a> Interpreter<'a> { let array = self.evaluate(index.collection)?; let index = self.evaluate(index.index)?; - let location = self.interner.expr_location(&id); + let location = self.elaborator.interner.expr_location(&id); let (array, index) = self.bounds_check(array, index, location)?; Ok(array[index].clone()) @@ -1059,7 +1067,7 @@ impl<'a> Interpreter<'a> { }) .collect::>()?; - let typ = self.interner.id_type(id).follow_bindings(); + let typ = self.elaborator.interner.id_type(id).follow_bindings(); Ok(Value::Struct(fields, typ)) } @@ -1079,13 +1087,13 @@ impl<'a> Interpreter<'a> { (fields, Type::Tuple(field_types)) } value => { - let location = self.interner.expr_location(&id); + let location = self.elaborator.interner.expr_location(&id); return Err(InterpreterError::NonTupleOrStructInMemberAccess { value, location }); } }; fields.get(&access.rhs.0.contents).cloned().ok_or_else(|| { - let location = self.interner.expr_location(&id); + let location = self.elaborator.interner.expr_location(&id); let value = Value::Struct(fields, struct_type); let field_name = access.rhs.0.contents; InterpreterError::ExpectedStructToHaveField { value, field_name, location } @@ -1095,14 +1103,22 @@ impl<'a> Interpreter<'a> { fn evaluate_call(&mut self, call: HirCallExpression, id: ExprId) -> IResult { let function = self.evaluate(call.func)?; let arguments = try_vecmap(call.arguments, |arg| { - Ok((self.evaluate(arg)?, self.interner.expr_location(&arg))) + Ok((self.evaluate(arg)?, self.elaborator.interner.expr_location(&arg))) })?; - let location = self.interner.expr_location(&id); + let location = self.elaborator.interner.expr_location(&id); match function { Value::Function(function_id, _, bindings) => { let bindings = unwrap_rc(bindings); - self.call_function(function_id, arguments, bindings, location) + let mut result = self.call_function(function_id, arguments, bindings, location)?; + if call.is_macro_call { + let expr = result.into_expression(self.elaborator.interner, location)?; + let expr = self + .elaborator + .elaborate_expression_from_comptime(expr, self.current_function); + result = self.evaluate(expr)?; + } + Ok(result) } Value::Closure(closure, env, _) => self.call_closure(closure, env, arguments, location), value => Err(InterpreterError::NonFunctionCalled { value, location }), @@ -1116,19 +1132,22 @@ impl<'a> Interpreter<'a> { ) -> IResult { let object = self.evaluate(call.object)?; let arguments = try_vecmap(call.arguments, |arg| { - Ok((self.evaluate(arg)?, self.interner.expr_location(&arg))) + Ok((self.evaluate(arg)?, self.elaborator.interner.expr_location(&arg))) })?; - let location = self.interner.expr_location(&id); + let location = self.elaborator.interner.expr_location(&id); let typ = object.get_type().follow_bindings(); let method_name = &call.method.0.contents; // TODO: Traits let method = match &typ { - Type::Struct(struct_def, _) => { - self.interner.lookup_method(&typ, struct_def.borrow().id, method_name, false) - } - _ => self.interner.lookup_primitive_method(&typ, method_name), + Type::Struct(struct_def, _) => self.elaborator.interner.lookup_method( + &typ, + struct_def.borrow().id, + method_name, + false, + ), + _ => self.elaborator.interner.lookup_primitive_method(&typ, method_name), }; if let Some(method) = method { @@ -1140,7 +1159,7 @@ impl<'a> Interpreter<'a> { fn evaluate_cast(&mut self, cast: &HirCastExpression, id: ExprId) -> IResult { let evaluated_lhs = self.evaluate(cast.lhs)?; - Self::evaluate_cast_one_step(cast, id, evaluated_lhs, self.interner) + Self::evaluate_cast_one_step(cast, id, evaluated_lhs, self.elaborator.interner) } /// evaluate_cast without recursion @@ -1242,7 +1261,7 @@ impl<'a> Interpreter<'a> { let condition = match self.evaluate(if_.condition)? { Value::Bool(value) => value, value => { - let location = self.interner.expr_location(&id); + let location = self.elaborator.interner.expr_location(&id); return Err(InterpreterError::NonBoolUsedInIf { value, location }); } }; @@ -1273,22 +1292,22 @@ impl<'a> Interpreter<'a> { } fn evaluate_lambda(&mut self, lambda: HirLambda, id: ExprId) -> IResult { - let location = self.interner.expr_location(&id); + let location = self.elaborator.interner.expr_location(&id); let environment = try_vecmap(&lambda.captures, |capture| self.lookup_id(capture.ident.id, location))?; - let typ = self.interner.id_type(id).follow_bindings(); + let typ = self.elaborator.interner.id_type(id).follow_bindings(); Ok(Value::Closure(lambda, environment, typ)) } fn evaluate_quote(&mut self, mut tokens: Tokens, expr_id: ExprId) -> IResult { - let location = self.interner.expr_location(&expr_id); + let location = self.elaborator.interner.expr_location(&expr_id); tokens = self.substitute_unquoted_values_into_tokens(tokens, location)?; Ok(Value::Code(Rc::new(tokens))) } pub fn evaluate_statement(&mut self, statement: StmtId) -> IResult { - match self.interner.statement(&statement) { + match self.elaborator.interner.statement(&statement) { HirStatement::Let(let_) => self.evaluate_let(let_), HirStatement::Constrain(constrain) => self.evaluate_constrain(constrain), HirStatement::Assign(assign) => self.evaluate_assign(assign), @@ -1302,7 +1321,7 @@ impl<'a> Interpreter<'a> { Ok(Value::Unit) } HirStatement::Error => { - let location = self.interner.id_location(statement); + let location = self.elaborator.interner.id_location(statement); Err(InterpreterError::ErrorNodeEncountered { location }) } } @@ -1310,7 +1329,7 @@ impl<'a> Interpreter<'a> { pub fn evaluate_let(&mut self, let_: HirLetStatement) -> IResult { let rhs = self.evaluate(let_.expression)?; - let location = self.interner.expr_location(&let_.expression); + let location = self.elaborator.interner.expr_location(&let_.expression); self.define_pattern(&let_.pattern, &let_.r#type, rhs, location)?; Ok(Value::Unit) } @@ -1319,12 +1338,12 @@ impl<'a> Interpreter<'a> { match self.evaluate(constrain.0)? { Value::Bool(true) => Ok(Value::Unit), Value::Bool(false) => { - let location = self.interner.expr_location(&constrain.0); + let location = self.elaborator.interner.expr_location(&constrain.0); let message = constrain.2.and_then(|expr| self.evaluate(expr).ok()); Err(InterpreterError::FailingConstraint { location, message }) } value => { - let location = self.interner.expr_location(&constrain.0); + let location = self.elaborator.interner.expr_location(&constrain.0); Err(InterpreterError::NonBoolUsedInConstrain { value, location }) } } @@ -1444,7 +1463,7 @@ impl<'a> Interpreter<'a> { Value::U32(value) => Ok((value as i128, |i| Value::U32(i as u32))), Value::U64(value) => Ok((value as i128, |i| Value::U64(i as u64))), value => { - let location = this.interner.expr_location(&expr); + let location = this.elaborator.interner.expr_location(&expr); Err(InterpreterError::NonIntegerUsedInLoop { value, location }) } } @@ -1476,7 +1495,7 @@ impl<'a> Interpreter<'a> { if self.in_loop { Err(InterpreterError::Break) } else { - let location = self.interner.statement_location(id); + let location = self.elaborator.interner.statement_location(id); Err(InterpreterError::BreakNotInLoop { location }) } } @@ -1485,7 +1504,7 @@ impl<'a> Interpreter<'a> { if self.in_loop { Err(InterpreterError::Continue) } else { - let location = self.interner.statement_location(id); + let location = self.elaborator.interner.statement_location(id); Err(InterpreterError::ContinueNotInLoop { location }) } } diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter/unquote.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter/unquote.rs index a1ceb27afb2..94a848b891d 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter/unquote.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter/unquote.rs @@ -7,7 +7,7 @@ use crate::{ use super::Interpreter; -impl<'a> Interpreter<'a> { +impl<'local, 'interner> Interpreter<'local, 'interner> { /// Evaluates any expressions within UnquoteMarkers in the given token list /// and replaces the expression held by the marker with the evaluated value /// in expression form. @@ -27,7 +27,8 @@ impl<'a> Interpreter<'a> { // turning it into a Quoted block (which would add `quote`, `{`, and `}` tokens). Value::Code(stream) => new_tokens.extend(unwrap_rc(stream).0), value => { - let new_id = value.into_hir_expression(self.interner, location)?; + let new_id = + value.into_hir_expression(self.elaborator.interner, location)?; let new_token = Token::UnquoteMarker(new_id); new_tokens.push(SpannedToken::new(new_token, span)); } diff --git a/compiler/noirc_frontend/src/hir/comptime/tests.rs b/compiler/noirc_frontend/src/hir/comptime/tests.rs index 0f58a2cda95..b4ffa1bd01d 100644 --- a/compiler/noirc_frontend/src/hir/comptime/tests.rs +++ b/compiler/noirc_frontend/src/hir/comptime/tests.rs @@ -8,19 +8,15 @@ use noirc_arena::Index; use noirc_errors::Location; use super::errors::InterpreterError; -use super::interpreter::Interpreter; use super::value::Value; use crate::elaborator::Elaborator; -use crate::graph::CrateId; use crate::hir::def_collector::dc_crate::DefCollector; use crate::hir::def_collector::dc_mod::collect_defs; use crate::hir::def_map::{CrateDefMap, LocalModuleId, ModuleData}; use crate::hir::{Context, ParsedFiles}; -use crate::macros_api::NodeInterner; -use crate::node_interner::FuncId; use crate::parser::parse_program; -fn elaborate_src_code(src: &str) -> (NodeInterner, FuncId) { +fn interpret_helper(src: &str) -> Result { let file = FileId::default(); // Can't use Index::test_new here for some reason, even with #[cfg(test)]. @@ -47,21 +43,15 @@ fn elaborate_src_code(src: &str) -> (NodeInterner, FuncId) { collect_defs(&mut collector, ast, FileId::dummy(), module_id, krate, &mut context, &[]); context.def_maps.insert(krate, collector.def_map); - let errors = Elaborator::elaborate(&mut context, krate, collector.items, None); - assert_eq!(errors.len(), 0); - let main = context.get_main_function(&krate).expect("Expected 'main' function"); + let mut elaborator = + Elaborator::elaborate_and_return_self(&mut context, krate, collector.items, None); + assert_eq!(elaborator.errors.len(), 0); - (context.def_interner, main) -} - -fn interpret_helper(src: &str) -> Result { - let (mut interner, main_id) = elaborate_src_code(src); - let mut scopes = vec![HashMap::default()]; - let mut interpreter = Interpreter::new(&mut interner, &mut scopes, CrateId::Root(0)); + let mut interpreter = elaborator.setup_interpreter(); let no_location = Location::dummy(); - interpreter.call_function(main_id, Vec::new(), HashMap::new(), no_location) + interpreter.call_function(main, Vec::new(), HashMap::new(), no_location) } fn interpret(src: &str) -> Value { diff --git a/compiler/noirc_frontend/src/hir_def/expr.rs b/compiler/noirc_frontend/src/hir_def/expr.rs index 21c222b481c..e85d30f0c32 100644 --- a/compiler/noirc_frontend/src/hir_def/expr.rs +++ b/compiler/noirc_frontend/src/hir_def/expr.rs @@ -171,6 +171,7 @@ pub struct HirCallExpression { pub func: ExprId, pub arguments: Vec, pub location: Location, + pub is_macro_call: bool, } /// These nodes are temporary, they're @@ -208,6 +209,7 @@ impl HirMethodCallExpression { mut self, method: &HirMethodReference, object_type: Type, + is_macro_call: bool, location: Location, interner: &mut NodeInterner, ) -> ((ExprId, HirIdent), HirCallExpression) { @@ -232,7 +234,7 @@ impl HirMethodCallExpression { let func_var = HirIdent { location, id, impl_kind }; let func = interner.push_expr(HirExpression::Ident(func_var.clone(), self.generics)); interner.push_expr_location(func, location.span, location.file); - let expr = HirCallExpression { func, arguments, location }; + let expr = HirCallExpression { func, arguments, location, is_macro_call }; ((func, func_var), expr) } } diff --git a/compiler/noirc_frontend/src/hir_def/function.rs b/compiler/noirc_frontend/src/hir_def/function.rs index dc563a5f65f..9fa480e88b0 100644 --- a/compiler/noirc_frontend/src/hir_def/function.rs +++ b/compiler/noirc_frontend/src/hir_def/function.rs @@ -6,6 +6,7 @@ use super::stmt::HirPattern; use super::traits::TraitConstraint; use crate::ast::{FunctionKind, FunctionReturnType, Visibility}; use crate::graph::CrateId; +use crate::hir::def_map::LocalModuleId; use crate::macros_api::{BlockExpression, StructId}; use crate::node_interner::{ExprId, NodeInterner, TraitImplId}; use crate::{ResolvedGeneric, Type}; @@ -154,6 +155,9 @@ pub struct FuncMeta { /// The crate this function was defined in pub source_crate: CrateId, + + /// The module this function was defined in + pub source_module: LocalModuleId, } #[derive(Debug, Clone)] diff --git a/compiler/noirc_frontend/src/hir_def/types.rs b/compiler/noirc_frontend/src/hir_def/types.rs index 37cda2bd04d..0ec975a04db 100644 --- a/compiler/noirc_frontend/src/hir_def/types.rs +++ b/compiler/noirc_frontend/src/hir_def/types.rs @@ -2118,7 +2118,8 @@ fn convert_array_expression_to_slice( interner.push_expr_location(argument, location.span, location.file); let arguments = vec![argument]; - let call = HirExpression::Call(HirCallExpression { func, arguments, location }); + let is_macro_call = false; + let call = HirExpression::Call(HirCallExpression { func, arguments, location, is_macro_call }); interner.replace_expr(&expression, call); interner.push_expr_location(func, location.span, location.file); diff --git a/test_programs/compile_success_empty/macros_in_comptime/Nargo.toml b/test_programs/compile_success_empty/macros_in_comptime/Nargo.toml new file mode 100644 index 00000000000..831fa270863 --- /dev/null +++ b/test_programs/compile_success_empty/macros_in_comptime/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "macros_in_comptime" +type = "bin" +authors = [""] +compiler_version = ">=0.32.0" + +[dependencies] diff --git a/test_programs/compile_success_empty/macros_in_comptime/src/main.nr b/test_programs/compile_success_empty/macros_in_comptime/src/main.nr new file mode 100644 index 00000000000..52567025e23 --- /dev/null +++ b/test_programs/compile_success_empty/macros_in_comptime/src/main.nr @@ -0,0 +1,49 @@ +use std::field::modulus_num_bits; +use std::meta::unquote; + +fn main() { + comptime + { + foo::<3>(5); + submodule::bar(); + } +} + +// Call a different function from the interpreter, then have the +// elaborator switch to the middle of foo from its previous scope in main +unconstrained comptime fn foo(x: Field) { + assert(modulus_num_bits() != 0); + + let cond = quote { modulus_num_bits() != 0 }; + assert(unquote!(cond)); + + // Use a comptime parameter in scope + assert_eq(5, x); + assert_eq(5, unquote!(quote { x })); + + // Use a generic in scope + assert_eq(3, N); + assert_eq(3, unquote!(quote { N })); + + // Use `break` which only unconstrained functions can do. + // This ensures the elaborator knows we're switching from `main` to `foo` + for _ in 0..0 { + break; + } + + let loop = quote { for _ in 0..0 { break; } }; + unquote!(loop); +} + +mod submodule { + use std::field::modulus_be_bytes; + use std::meta::unquote; + + pub comptime fn bar() { + // Use a function only in scope in this module + assert(modulus_be_bytes().len() != 0); + + let cond = quote { modulus_be_bytes().len() != 0 }; + assert(unquote!(cond)); + } +}