diff --git a/compiler/noirc_frontend/src/elaborator/expressions.rs b/compiler/noirc_frontend/src/elaborator/expressions.rs index bc759f41e57..7d304990dd8 100644 --- a/compiler/noirc_frontend/src/elaborator/expressions.rs +++ b/compiler/noirc_frontend/src/elaborator/expressions.rs @@ -759,7 +759,8 @@ impl<'context> Elaborator<'context> { return None; } - let result = interpreter.call_function(function, comptime_args, location); + let bindings = interpreter.interner.get_instantiation_bindings(func).clone(); + let result = interpreter.call_function(function, comptime_args, bindings, location); let (expr_id, typ) = self.inline_comptime_value(result, location.span); Some((self.interner.expression(&expr_id), typ)) } diff --git a/compiler/noirc_frontend/src/elaborator/mod.rs b/compiler/noirc_frontend/src/elaborator/mod.rs index c7d18c39ba7..ae8237706cc 100644 --- a/compiler/noirc_frontend/src/elaborator/mod.rs +++ b/compiler/noirc_frontend/src/elaborator/mod.rs @@ -33,7 +33,7 @@ use crate::{ DefinitionId, DefinitionKind, DependencyId, ExprId, FuncId, GlobalId, TraitId, TypeAliasId, }, parser::TopLevelStatement, - Shared, Type, TypeVariable, + Shared, Type, TypeBindings, TypeVariable, }; use crate::{ ast::{TraitBound, UnresolvedGeneric, UnresolvedGenerics}, @@ -1273,7 +1273,7 @@ impl<'context> Elaborator<'context> { let arguments = vec![(Value::TypeDefinition(struct_id), location)]; let value = interpreter - .call_function(function, arguments, location) + .call_function(function, arguments, TypeBindings::new(), location) .map_err(|error| error.into_compilation_error_pair())?; if value != Value::Unit { diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter.rs index 8acb5d2d315..d2b98569bbb 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter.rs @@ -8,6 +8,7 @@ use rustc_hash::FxHashMap as HashMap; use crate::ast::{BinaryOpKind, FunctionKind, IntegerBitSize, Signedness}; use crate::graph::CrateId; +use crate::monomorphization::{perform_instantiation_bindings, undo_instantiation_bindings}; use crate::token::Tokens; use crate::{ hir_def::{ @@ -28,7 +29,7 @@ use crate::{ }; use super::errors::{IResult, InterpreterError}; -use super::value::Value; +use super::value::{unwrap_rc, Value}; mod builtin; mod unquote; @@ -59,6 +60,19 @@ impl<'a> Interpreter<'a> { } pub(crate) fn call_function( + &mut self, + function: FuncId, + arguments: Vec<(Value, Location)>, + instantiation_bindings: TypeBindings, + location: Location, + ) -> IResult { + perform_instantiation_bindings(&instantiation_bindings); + let result = self.call_function_inner(function, arguments, location); + undo_instantiation_bindings(instantiation_bindings); + result + } + + fn call_function_inner( &mut self, function: FuncId, arguments: Vec<(Value, Location)>, @@ -200,7 +214,8 @@ impl<'a> Interpreter<'a> { ) -> IResult<()> { match pattern { HirPattern::Identifier(identifier) => { - self.define(identifier.id, typ, argument, location) + self.define(identifier.id, argument); + Ok(()) } HirPattern::Mutable(pattern, _) => { self.define_pattern(pattern, typ, argument, location) @@ -222,8 +237,6 @@ impl<'a> Interpreter<'a> { }, HirPattern::Struct(struct_type, pattern_fields, _) => { self.push_scope(); - self.type_check(typ, &argument, location)?; - self.type_check(struct_type, &argument, location)?; let res = match argument { Value::Struct(fields, struct_type) if fields.len() == pattern_fields.len() => { @@ -259,30 +272,8 @@ impl<'a> Interpreter<'a> { } /// Define a new variable in the current scope - fn define( - &mut self, - id: DefinitionId, - typ: &Type, - argument: Value, - location: Location, - ) -> IResult<()> { - // Temporarily disabled since this fails on generic types - // self.type_check(typ, &argument, location)?; + fn define(&mut self, id: DefinitionId, argument: Value) { self.current_scope_mut().insert(id, argument); - Ok(()) - } - - /// Mutate an existing variable, potentially from a prior scope. - /// Also type checks the value being assigned - fn checked_mutate( - &mut self, - id: DefinitionId, - typ: &Type, - argument: Value, - location: Location, - ) -> IResult<()> { - self.type_check(typ, &argument, location)?; - self.mutate(id, argument, location) } /// Mutate an existing variable, potentially from a prior scope @@ -321,15 +312,6 @@ impl<'a> Interpreter<'a> { Err(InterpreterError::NonComptimeVarReferenced { name, location }) } - fn type_check(&self, typ: &Type, value: &Value, location: Location) -> IResult<()> { - let typ = typ.follow_bindings(); - let value_type = value.get_type(); - - typ.try_unify(&value_type, &mut TypeBindings::new()).map_err(|_| { - InterpreterError::TypeMismatch { expected: typ, value: value.clone(), location } - }) - } - /// Evaluate an expression and return the result pub fn evaluate(&mut self, id: ExprId) -> IResult { match self.interner.expression(&id) { @@ -367,8 +349,9 @@ impl<'a> Interpreter<'a> { match &definition.kind { DefinitionKind::Function(function_id) => { - let typ = self.interner.id_type(id); - Ok(Value::Function(*function_id, typ)) + let typ = self.interner.id_type(id).follow_bindings(); + let bindings = Rc::new(self.interner.get_instantiation_bindings(id).clone()); + Ok(Value::Function(*function_id, typ, bindings)) } DefinitionKind::Local(_) => self.lookup(&ident), DefinitionKind::Global(global_id) => { @@ -539,7 +522,7 @@ impl<'a> Interpreter<'a> { } fn evaluate_array(&mut self, array: HirArrayLiteral, id: ExprId) -> IResult { - let typ = self.interner.id_type(id); + let typ = self.interner.id_type(id).follow_bindings(); match array { HirArrayLiteral::Standard(elements) => { @@ -936,7 +919,7 @@ impl<'a> Interpreter<'a> { }) .collect::>()?; - let typ = self.interner.id_type(id); + let typ = self.interner.id_type(id).follow_bindings(); Ok(Value::Struct(fields, typ)) } @@ -977,7 +960,10 @@ impl<'a> Interpreter<'a> { let location = self.interner.expr_location(&id); match function { - Value::Function(function_id, _) => self.call_function(function_id, arguments, location), + Value::Function(function_id, _, bindings) => { + let bindings = unwrap_rc(bindings); + self.call_function(function_id, arguments, bindings, location) + } Value::Closure(closure, env, _) => self.call_closure(closure, env, arguments, location), value => Err(InterpreterError::NonFunctionCalled { value, location }), } @@ -1006,7 +992,7 @@ impl<'a> Interpreter<'a> { }; if let Some(method) = method { - self.call_function(method, arguments, location) + self.call_function(method, arguments, TypeBindings::new(), location) } else { Err(InterpreterError::NoMethodFound { name: method_name.clone(), typ, location }) } @@ -1151,7 +1137,7 @@ impl<'a> Interpreter<'a> { let environment = try_vecmap(&lambda.captures, |capture| self.lookup_id(capture.ident.id, location))?; - let typ = self.interner.id_type(id); + let typ = self.interner.id_type(id).follow_bindings(); Ok(Value::Closure(lambda, environment, typ)) } @@ -1212,9 +1198,7 @@ impl<'a> Interpreter<'a> { fn store_lvalue(&mut self, lvalue: HirLValue, rhs: Value) -> IResult<()> { match lvalue { - HirLValue::Ident(ident, typ) => { - self.checked_mutate(ident.id, &typ, rhs, ident.location) - } + HirLValue::Ident(ident, typ) => self.mutate(ident.id, rhs, ident.location), HirLValue::Dereference { lvalue, element_type: _, location } => { match self.evaluate_lvalue(&lvalue)? { Value::Pointer(value) => { @@ -1233,7 +1217,7 @@ impl<'a> Interpreter<'a> { } Value::Struct(mut fields, typ) => { fields.insert(Rc::new(field_name.0.contents), rhs); - self.store_lvalue(*object, Value::Struct(fields, typ)) + self.store_lvalue(*object, Value::Struct(fields, typ.follow_bindings())) } value => { Err(InterpreterError::NonTupleOrStructInMemberAccess { value, location }) diff --git a/compiler/noirc_frontend/src/hir/comptime/tests.rs b/compiler/noirc_frontend/src/hir/comptime/tests.rs index 54d8f1f1fad..870f2bc458a 100644 --- a/compiler/noirc_frontend/src/hir/comptime/tests.rs +++ b/compiler/noirc_frontend/src/hir/comptime/tests.rs @@ -16,7 +16,7 @@ fn interpret_helper(src: &str, func_namespace: Vec) -> Result) -> Value { @@ -197,3 +197,18 @@ fn non_deterministic_recursion() { let result = interpret(program, vec!["main".into(), "fib".into()]); assert_eq!(result, Value::U64(55)); } + +#[test] +fn generic_functions() { + let program = " + fn main() -> pub u8 { + apply(1, |x| x + 1) + } + + fn apply(x: T, f: fn[Env](T) -> U) -> U { + f(x) + } + "; + let result = interpret(program, vec!["main".into(), "apply".into()]); + assert!(matches!(result, Value::U8(2))); +} diff --git a/compiler/noirc_frontend/src/hir/comptime/value.rs b/compiler/noirc_frontend/src/hir/comptime/value.rs index 2fecd868977..c956cdb5796 100644 --- a/compiler/noirc_frontend/src/hir/comptime/value.rs +++ b/compiler/noirc_frontend/src/hir/comptime/value.rs @@ -16,7 +16,7 @@ use crate::{ node_interner::{ExprId, FuncId}, parser::{self, NoirParser, TopLevelStatement}, token::{SpannedToken, Token, Tokens}, - QuotedType, Shared, Type, + QuotedType, Shared, Type, TypeBindings, }; use rustc_hash::FxHashMap as HashMap; @@ -36,7 +36,7 @@ pub enum Value { U32(u32), U64(u64), String(Rc), - Function(FuncId, Type), + Function(FuncId, Type, Rc), Closure(HirLambda, Vec, Type), Tuple(Vec), Struct(HashMap, Value>, Type), @@ -65,7 +65,7 @@ impl Value { let length = Type::Constant(value.len() as u32); Type::String(Box::new(length)) } - Value::Function(_, typ) => return Cow::Borrowed(typ), + Value::Function(_, typ, _) => return Cow::Borrowed(typ), Value::Closure(_, _, typ) => return Cow::Borrowed(typ), Value::Tuple(fields) => { Type::Tuple(vecmap(fields, |field| field.get_type().into_owned())) @@ -128,13 +128,14 @@ impl Value { ExpressionKind::Literal(Literal::Integer((value as u128).into(), false)) } Value::String(value) => ExpressionKind::Literal(Literal::Str(unwrap_rc(value))), - Value::Function(id, typ) => { + Value::Function(id, typ, bindings) => { let id = interner.function_definition_id(id); let impl_kind = ImplKind::NotATraitMethod; let ident = HirIdent { location, id, impl_kind }; let expr_id = interner.push_expr(HirExpression::Ident(ident, None)); interner.push_expr_location(expr_id, location.span, location.file); interner.push_expr_type(expr_id, typ); + interner.store_instantiation_bindings(expr_id, unwrap_rc(bindings)); ExpressionKind::Resolved(expr_id) } Value::Closure(_lambda, _env, _typ) => { @@ -247,10 +248,15 @@ impl Value { HirExpression::Literal(HirLiteral::Integer((value as u128).into(), false)) } Value::String(value) => HirExpression::Literal(HirLiteral::Str(unwrap_rc(value))), - Value::Function(id, _typ) => { + Value::Function(id, typ, bindings) => { let id = interner.function_definition_id(id); let impl_kind = ImplKind::NotATraitMethod; - HirExpression::Ident(HirIdent { location, id, impl_kind }, None) + let ident = HirIdent { location, id, impl_kind }; + let expr_id = interner.push_expr(HirExpression::Ident(ident, None)); + interner.push_expr_location(expr_id, location.span, location.file); + interner.push_expr_type(expr_id, typ); + interner.store_instantiation_bindings(expr_id, unwrap_rc(bindings)); + return Ok(expr_id); } Value::Closure(_lambda, _env, _typ) => { // TODO: How should a closure's environment be inlined? @@ -362,7 +368,7 @@ impl Display for Value { Value::U32(value) => write!(f, "{value}"), Value::U64(value) => write!(f, "{value}"), Value::String(value) => write!(f, "{value}"), - Value::Function(_, _) => write!(f, "(function)"), + Value::Function(..) => write!(f, "(function)"), Value::Closure(_, _, _) => write!(f, "(closure)"), Value::Tuple(fields) => { let fields = vecmap(fields, ToString::to_string); diff --git a/compiler/noirc_frontend/src/monomorphization/mod.rs b/compiler/noirc_frontend/src/monomorphization/mod.rs index 7eb2afe804a..d589294d4fd 100644 --- a/compiler/noirc_frontend/src/monomorphization/mod.rs +++ b/compiler/noirc_frontend/src/monomorphization/mod.rs @@ -1816,13 +1816,13 @@ fn unwrap_struct_type(typ: &HirType) -> Vec<(String, HirType)> { } } -fn perform_instantiation_bindings(bindings: &TypeBindings) { +pub fn perform_instantiation_bindings(bindings: &TypeBindings) { for (var, binding) in bindings.values() { var.force_bind(binding.clone()); } } -fn undo_instantiation_bindings(bindings: TypeBindings) { +pub fn undo_instantiation_bindings(bindings: TypeBindings) { for (id, (var, _)) in bindings { var.unbind(id); }