Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add TraitConstraint type #5499

Merged
merged 8 commits into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions aztec_macros/src/transforms/note_interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ pub fn generate_note_interface_impl(module: &mut SortedModule) -> Result<(), Azt
generics: vec![],
methods: vec![],
where_clause: vec![],
is_comptime: false,
};
module.impls.push(default_impl.clone());
module.impls.last_mut().unwrap()
Expand Down
1 change: 1 addition & 0 deletions aztec_macros/src/transforms/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@
///
/// To:
///
/// impl<Context> Storage<Contex> {

Check warning on line 176 in aztec_macros/src/transforms/storage.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (Contex)
/// fn init(context: Context) -> Self {
/// Storage {
/// a_map: Map::new(context, 0, |context, slot| {
Expand Down Expand Up @@ -248,6 +248,7 @@
methods: vec![(init, Span::default())],

where_clause: vec![],
is_comptime: false,
};
module.impls.push(storage_impl);

Expand Down
13 changes: 1 addition & 12 deletions compiler/noirc_frontend/src/ast/structure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,7 @@ pub struct NoirStruct {
pub generics: UnresolvedGenerics,
pub fields: Vec<(Ident, UnresolvedType)>,
pub span: Span,
}

impl NoirStruct {
pub fn new(
name: Ident,
attributes: Vec<SecondaryAttribute>,
generics: UnresolvedGenerics,
fields: Vec<(Ident, UnresolvedType)>,
span: Span,
) -> NoirStruct {
NoirStruct { name, attributes, generics, fields, span }
}
pub is_comptime: bool,
}

impl Display for NoirStruct {
Expand Down
5 changes: 4 additions & 1 deletion compiler/noirc_frontend/src/ast/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ pub struct TypeImpl {
pub generics: UnresolvedGenerics,
pub where_clause: Vec<UnresolvedTraitConstraint>,
pub methods: Vec<(NoirFunction, Span)>,
pub is_comptime: bool,
}

/// Ast node for an implementation of a trait for a particular type
Expand All @@ -69,6 +70,8 @@ pub struct NoirTraitImpl {
pub where_clause: Vec<UnresolvedTraitConstraint>,

pub items: Vec<TraitImplItem>,

pub is_comptime: bool,
}

/// Represents a simple trait constraint such as `where Foo: TraitY<U, V>`
Expand All @@ -84,7 +87,7 @@ pub struct UnresolvedTraitConstraint {
}

/// Represents a single trait bound, such as `TraitX` or `TraitY<U, V>`
#[derive(Clone, Debug, PartialEq, Eq)]
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct TraitBound {
pub trait_path: Path,
pub trait_id: Option<TraitId>, // initially None, gets assigned during DC
Expand Down
152 changes: 85 additions & 67 deletions compiler/noirc_frontend/src/elaborator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1636,17 +1636,25 @@ impl<'context> Elaborator<'context> {
function_sets.push(UnresolvedFunctions { functions, file_id, trait_id, self_type });
}

let (comptime_trait_impls, trait_impls) =
items.trait_impls.into_iter().partition(|trait_impl| trait_impl.is_comptime);

let (comptime_structs, structs) =
items.types.into_iter().partition(|typ| typ.1.struct_def.is_comptime);

let comptime = CollectedItems {
functions: comptime_function_sets,
types: BTreeMap::new(),
types: comptime_structs,
type_aliases: BTreeMap::new(),
traits: BTreeMap::new(),
trait_impls: Vec::new(),
trait_impls: comptime_trait_impls,
globals: Vec::new(),
impls: rustc_hash::FxHashMap::default(),
};

items.functions = function_sets;
items.trait_impls = trait_impls;
items.types = structs;
(comptime, items)
}

Expand All @@ -1657,75 +1665,85 @@ impl<'context> Elaborator<'context> {
location: Location,
) {
for item in items {
match item {
TopLevelStatement::Function(function) => {
let id = self.interner.push_empty_fn();
let module = self.module_id();
self.interner.push_function(id, &function.def, module, location);
let functions = vec![(self.local_module, id, function)];
generated_items.functions.push(UnresolvedFunctions {
file_id: self.file,
functions,
trait_id: None,
self_type: None,
});
}
TopLevelStatement::TraitImpl(mut trait_impl) => {
let methods = dc_mod::collect_trait_impl_functions(
self.interner,
&mut trait_impl,
self.crate_id,
self.file,
self.local_module,
);
self.add_item(item, generated_items, location);
}
}

generated_items.trait_impls.push(UnresolvedTraitImpl {
file_id: self.file,
module_id: self.local_module,
trait_generics: trait_impl.trait_generics,
trait_path: trait_impl.trait_name,
object_type: trait_impl.object_type,
methods,
generics: trait_impl.impl_generics,
where_clause: trait_impl.where_clause,

// These last fields are filled in later
trait_id: None,
impl_id: None,
resolved_object_type: None,
resolved_generics: Vec::new(),
resolved_trait_generics: Vec::new(),
});
}
TopLevelStatement::Global(global) => {
let (global, error) = dc_mod::collect_global(
self.interner,
self.def_maps.get_mut(&self.crate_id).unwrap(),
global,
self.file,
self.local_module,
);
fn add_item(
&mut self,
item: TopLevelStatement,
generated_items: &mut CollectedItems,
location: Location,
) {
match item {
TopLevelStatement::Function(function) => {
let id = self.interner.push_empty_fn();
let module = self.module_id();
self.interner.push_function(id, &function.def, module, location);
let functions = vec![(self.local_module, id, function)];
generated_items.functions.push(UnresolvedFunctions {
file_id: self.file,
functions,
trait_id: None,
self_type: None,
});
}
TopLevelStatement::TraitImpl(mut trait_impl) => {
let methods = dc_mod::collect_trait_impl_functions(
self.interner,
&mut trait_impl,
self.crate_id,
self.file,
self.local_module,
);

generated_items.globals.push(global);
if let Some(error) = error {
self.errors.push(error);
}
}
// Assume that an error has already been issued
TopLevelStatement::Error => (),

TopLevelStatement::Module(_)
| TopLevelStatement::Import(_)
| TopLevelStatement::Struct(_)
| TopLevelStatement::Trait(_)
| TopLevelStatement::Impl(_)
| TopLevelStatement::TypeAlias(_)
| TopLevelStatement::SubModule(_) => {
let item = item.to_string();
let error = InterpreterError::UnsupportedTopLevelItemUnquote { item, location };
self.errors.push(error.into_compilation_error_pair());
generated_items.trait_impls.push(UnresolvedTraitImpl {
file_id: self.file,
module_id: self.local_module,
trait_generics: trait_impl.trait_generics,
trait_path: trait_impl.trait_name,
object_type: trait_impl.object_type,
methods,
generics: trait_impl.impl_generics,
where_clause: trait_impl.where_clause,
is_comptime: trait_impl.is_comptime,

// These last fields are filled in later
trait_id: None,
impl_id: None,
resolved_object_type: None,
resolved_generics: Vec::new(),
resolved_trait_generics: Vec::new(),
});
}
TopLevelStatement::Global(global) => {
let (global, error) = dc_mod::collect_global(
self.interner,
self.def_maps.get_mut(&self.crate_id).unwrap(),
global,
self.file,
self.local_module,
);

generated_items.globals.push(global);
if let Some(error) = error {
self.errors.push(error);
}
}
// Assume that an error has already been issued
TopLevelStatement::Error => (),

TopLevelStatement::Module(_)
| TopLevelStatement::Import(_)
| TopLevelStatement::Struct(_)
| TopLevelStatement::Trait(_)
| TopLevelStatement::Impl(_)
| TopLevelStatement::TypeAlias(_)
| TopLevelStatement::SubModule(_) => {
let item = item.to_string();
let error = InterpreterError::UnsupportedTopLevelItemUnquote { item, location };
self.errors.push(error.into_compilation_error_pair());
}
}
}

Expand Down
93 changes: 86 additions & 7 deletions compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
use std::rc::Rc;
use std::{
hash::{Hash, Hasher},
rc::Rc,
};

use chumsky::Parser;
use noirc_errors::Location;

use crate::{
ast::IntegerBitSize,
ast::{IntegerBitSize, TraitBound},
hir::comptime::{errors::IResult, InterpreterError, Value},
macros_api::{NodeInterner, Signedness},
parser,
token::{SpannedToken, Token, Tokens},
QuotedType, Type,
};
Expand All @@ -29,6 +34,9 @@ pub(super) fn call_builtin(
"struct_def_as_type" => struct_def_as_type(interner, arguments, location),
"struct_def_fields" => struct_def_fields(interner, arguments, location),
"struct_def_generics" => struct_def_generics(interner, arguments, location),
"trait_constraint_eq" => trait_constraint_eq(interner, arguments, location),
"trait_constraint_hash" => trait_constraint_hash(interner, arguments, location),
"quoted_as_trait_constraint" => quoted_as_trait_constraint(interner, arguments, location),
_ => {
let item = format!("Comptime evaluation for builtin function {name}");
Err(InterpreterError::Unimplemented { item, location })
Expand Down Expand Up @@ -79,6 +87,26 @@ fn get_u32(value: Value, location: Location) -> IResult<u32> {
}
}

fn get_trait_constraint(value: Value, location: Location) -> IResult<TraitBound> {
match value {
Value::TraitConstraint(bound) => Ok(bound),
value => {
let expected = Type::Quoted(QuotedType::TraitConstraint);
Err(InterpreterError::TypeMismatch { expected, value, location })
}
}
}

fn get_quoted(value: Value, location: Location) -> IResult<Rc<Tokens>> {
match value {
Value::Code(tokens) => Ok(tokens),
value => {
let expected = Type::Quoted(QuotedType::Quoted);
Err(InterpreterError::TypeMismatch { expected, value, location })
}
}
}

fn array_len(
interner: &NodeInterner,
mut arguments: Vec<(Value, Location)>,
Expand Down Expand Up @@ -231,7 +259,7 @@ fn slice_remove(
interner: &mut NodeInterner,
mut arguments: Vec<(Value, Location)>,
location: Location,
) -> Result<Value, InterpreterError> {
) -> IResult<Value> {
check_argument_count(2, &arguments, location)?;

let index = get_u32(arguments.pop().unwrap().0, location)? as usize;
Expand All @@ -257,7 +285,7 @@ fn slice_push_front(
interner: &mut NodeInterner,
mut arguments: Vec<(Value, Location)>,
location: Location,
) -> Result<Value, InterpreterError> {
) -> IResult<Value> {
check_argument_count(2, &arguments, location)?;

let (element, _) = arguments.pop().unwrap();
Expand All @@ -270,7 +298,7 @@ fn slice_pop_front(
interner: &mut NodeInterner,
mut arguments: Vec<(Value, Location)>,
location: Location,
) -> Result<Value, InterpreterError> {
) -> IResult<Value> {
check_argument_count(1, &arguments, location)?;

let (mut values, typ) = get_slice(interner, arguments.pop().unwrap().0, location)?;
Expand All @@ -284,7 +312,7 @@ fn slice_pop_back(
interner: &mut NodeInterner,
mut arguments: Vec<(Value, Location)>,
location: Location,
) -> Result<Value, InterpreterError> {
) -> IResult<Value> {
check_argument_count(1, &arguments, location)?;

let (mut values, typ) = get_slice(interner, arguments.pop().unwrap().0, location)?;
Expand All @@ -298,7 +326,7 @@ fn slice_insert(
interner: &mut NodeInterner,
mut arguments: Vec<(Value, Location)>,
location: Location,
) -> Result<Value, InterpreterError> {
) -> IResult<Value> {
check_argument_count(3, &arguments, location)?;

let (element, _) = arguments.pop().unwrap();
Expand All @@ -307,3 +335,54 @@ fn slice_insert(
values.insert(index as usize, element);
Ok(Value::Slice(values, typ))
}

// fn as_trait_constraint(quoted: Quoted) -> TraitConstraint
fn quoted_as_trait_constraint(
_interner: &mut NodeInterner,
mut arguments: Vec<(Value, Location)>,
location: Location,
) -> IResult<Value> {
check_argument_count(1, &arguments, location)?;

let tokens = get_quoted(arguments.pop().unwrap().0, location)?;
let quoted = tokens.as_ref().clone();

let trait_bound = parser::trait_bound().parse(quoted).map_err(|mut errors| {
let error = errors.swap_remove(0);
let rule = "a trait constraint";
InterpreterError::FailedToParseMacro { error, tokens, rule, file: location.file }
})?;

Ok(Value::TraitConstraint(trait_bound))
}

// fn constraint_hash(constraint: TraitConstraint) -> Field
fn trait_constraint_hash(
_interner: &mut NodeInterner,
mut arguments: Vec<(Value, Location)>,
location: Location,
) -> IResult<Value> {
check_argument_count(1, &arguments, location)?;

let bound = get_trait_constraint(arguments.pop().unwrap().0, location)?;

let mut hasher = std::collections::hash_map::DefaultHasher::new();
bound.hash(&mut hasher);
let hash = hasher.finish();

Ok(Value::Field((hash as u128).into()))
}

// fn constraint_eq(constraint_a: TraitConstraint, constraint_b: TraitConstraint) -> bool
fn trait_constraint_eq(
_interner: &mut NodeInterner,
mut arguments: Vec<(Value, Location)>,
location: Location,
) -> IResult<Value> {
check_argument_count(2, &arguments, location)?;

let constraint_b = get_trait_constraint(arguments.pop().unwrap().0, location)?;
let constraint_a = get_trait_constraint(arguments.pop().unwrap().0, location)?;

Ok(Value::Bool(constraint_a == constraint_b))
}
Loading
Loading