Skip to content

Commit

Permalink
fix: allow calling trait impl method from struct if multiple impls ex…
Browse files Browse the repository at this point in the history
…ist (#7124)
  • Loading branch information
asterite authored Jan 21, 2025
1 parent 521f5ce commit 966d8a6
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 49 deletions.
7 changes: 7 additions & 0 deletions compiler/noirc_frontend/src/ast/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use super::{
};
use crate::ast::UnresolvedTypeData;
use crate::elaborator::types::SELF_TYPE_NAME;
use crate::elaborator::Turbofish;
use crate::lexer::token::SpannedToken;
use crate::node_interner::{
InternedExpressionKind, InternedPattern, InternedStatementKind, NodeInterner,
Expand Down Expand Up @@ -535,6 +536,12 @@ impl PathSegment {
pub fn turbofish_span(&self) -> Span {
Span::from(self.ident.span().end()..self.span.end())
}

pub fn turbofish(&self) -> Option<Turbofish> {
self.generics
.as_ref()
.map(|generics| Turbofish { span: self.turbofish_span(), generics: generics.clone() })
}
}

impl From<Ident> for PathSegment {
Expand Down
1 change: 1 addition & 0 deletions compiler/noirc_frontend/src/elaborator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ mod unquote;
use fm::FileId;
use iter_extended::vecmap;
use noirc_errors::{Location, Span, Spanned};
pub use path_resolution::Turbofish;
use path_resolution::{PathResolution, PathResolutionItem};
use types::bind_ordered_generics;

Expand Down
24 changes: 3 additions & 21 deletions compiler/noirc_frontend/src/elaborator/path_resolution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,13 +227,7 @@ impl<'context> Elaborator<'context> {
ModuleDefId::TypeId(id) => (
id.module_id(),
true,
IntermediatePathResolutionItem::Struct(
id,
last_segment_generics.as_ref().map(|generics| Turbofish {
generics: generics.clone(),
span: last_segment.turbofish_span(),
}),
),
IntermediatePathResolutionItem::Struct(id, last_segment.turbofish()),
),
ModuleDefId::TypeAliasId(id) => {
let type_alias = self.interner.get_type_alias(id);
Expand All @@ -244,25 +238,13 @@ impl<'context> Elaborator<'context> {
(
module_id,
true,
IntermediatePathResolutionItem::TypeAlias(
id,
last_segment_generics.as_ref().map(|generics| Turbofish {
generics: generics.clone(),
span: last_segment.turbofish_span(),
}),
),
IntermediatePathResolutionItem::TypeAlias(id, last_segment.turbofish()),
)
}
ModuleDefId::TraitId(id) => (
id.0,
false,
IntermediatePathResolutionItem::Trait(
id,
last_segment_generics.as_ref().map(|generics| Turbofish {
generics: generics.clone(),
span: last_segment.turbofish_span(),
}),
),
IntermediatePathResolutionItem::Trait(id, last_segment.turbofish()),
),
ModuleDefId::FunctionId(_) => panic!("functions cannot be in the type namespace"),
ModuleDefId::GlobalId(_) => panic!("globals cannot be in the type namespace"),
Expand Down
103 changes: 85 additions & 18 deletions compiler/noirc_frontend/src/elaborator/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -684,6 +684,60 @@ impl<'context> Elaborator<'context> {
None
}

/// This resolves a method in the form `Struct::method` where `method` is a trait method
fn resolve_struct_trait_method(&mut self, path: &Path) -> Option<TraitPathResolution> {
if path.segments.len() < 2 {
return None;
}

let mut path = path.clone();
let span = path.span();
let last_segment = path.pop();
let before_last_segment = path.last_segment();

let path_resolution = self.resolve_path(path).ok()?;
let PathResolutionItem::Struct(struct_id) = path_resolution.item else {
return None;
};

let struct_type = self.get_struct(struct_id);
let generics = struct_type.borrow().instantiate(self.interner);
let typ = Type::Struct(struct_type, generics);
let method_name = &last_segment.ident.0.contents;

// If we can find a method on the struct, this is definitely not a trait method
if self.interner.lookup_direct_method(&typ, method_name, false).is_some() {
return None;
}

let trait_methods = self.interner.lookup_trait_methods(&typ, method_name, false);
if trait_methods.is_empty() {
return None;
}

let (hir_method_reference, error) =
self.get_trait_method_in_scope(&trait_methods, method_name, last_segment.span);
let hir_method_reference = hir_method_reference?;
let func_id = hir_method_reference.func_id(self.interner)?;
let HirMethodReference::TraitMethodId(trait_method_id, _, _) = hir_method_reference else {
return None;
};

let trait_id = trait_method_id.trait_id;
let trait_ = self.interner.get_trait(trait_id);
let mut constraint = trait_.as_constraint(span);
constraint.typ = typ;

let method = TraitMethod { method_id: trait_method_id, constraint, assumed: false };
let turbofish = before_last_segment.turbofish();
let item = PathResolutionItem::TraitFunction(trait_id, turbofish, func_id);
let mut errors = path_resolution.errors;
if let Some(error) = error {
errors.push(error);
}
Some(TraitPathResolution { method, item: Some(item), errors })
}

// Try to resolve the given trait method path.
//
// Returns the trait method, trait constraint, and whether the impl is assumed to exist by a where clause or not
Expand All @@ -695,6 +749,7 @@ impl<'context> Elaborator<'context> {
self.resolve_trait_static_method_by_self(path)
.or_else(|| self.resolve_trait_static_method(path))
.or_else(|| self.resolve_trait_method_by_named_generic(path))
.or_else(|| self.resolve_struct_trait_method(path))
}

pub(super) fn unify(
Expand Down Expand Up @@ -1456,6 +1511,19 @@ impl<'context> Elaborator<'context> {
method_name: &str,
span: Span,
) -> Option<HirMethodReference> {
let (method, error) = self.get_trait_method_in_scope(trait_methods, method_name, span);
if let Some(error) = error {
self.push_err(error);
}
method
}

fn get_trait_method_in_scope(
&mut self,
trait_methods: &[(FuncId, TraitId)],
method_name: &str,
span: Span,
) -> (Option<HirMethodReference>, Option<PathResolutionError>) {
let module_id = self.module_id();
let module_data = self.get_module(module_id);

Expand Down Expand Up @@ -1489,28 +1557,24 @@ impl<'context> Elaborator<'context> {
let trait_id = *traits.iter().next().unwrap();
let trait_ = self.interner.get_trait(trait_id);
let trait_name = self.fully_qualified_trait_path(trait_);

self.push_err(PathResolutionError::TraitMethodNotInScope {
let method =
self.trait_hir_method_reference(trait_id, trait_methods, method_name, span);
let error = PathResolutionError::TraitMethodNotInScope {
ident: Ident::new(method_name.into(), span),
trait_name,
});

return Some(self.trait_hir_method_reference(
trait_id,
trait_methods,
method_name,
span,
));
};
return (Some(method), Some(error));
} else {
let traits = vecmap(traits, |trait_id| {
let trait_ = self.interner.get_trait(trait_id);
self.fully_qualified_trait_path(trait_)
});
self.push_err(PathResolutionError::UnresolvedWithPossibleTraitsToImport {
let method = None;
let error = PathResolutionError::UnresolvedWithPossibleTraitsToImport {
ident: Ident::new(method_name.into(), span),
traits,
});
return None;
};
return (method, Some(error));
}
}

Expand All @@ -1519,15 +1583,18 @@ impl<'context> Elaborator<'context> {
let trait_ = self.interner.get_trait(trait_id);
self.fully_qualified_trait_path(trait_)
});
self.push_err(PathResolutionError::MultipleTraitsInScope {
let method = None;
let error = PathResolutionError::MultipleTraitsInScope {
ident: Ident::new(method_name.into(), span),
traits,
});
return None;
};
return (method, Some(error));
}

let trait_id = traits_in_scope[0].0;
Some(self.trait_hir_method_reference(trait_id, trait_methods, method_name, span))
let method = self.trait_hir_method_reference(trait_id, trait_methods, method_name, span);
let error = None;
(Some(method), error)
}

fn trait_hir_method_reference(
Expand All @@ -1545,7 +1612,7 @@ impl<'context> Elaborator<'context> {

// Return a TraitMethodId with unbound generics. These will later be bound by the type-checker.
let trait_ = self.interner.get_trait(trait_id);
let generics = trait_.as_constraint(span).trait_bound.trait_generics;
let generics = trait_.get_trait_generics(span);
let trait_method_id = trait_.find_method(method_name).unwrap();
HirMethodReference::TraitMethodId(trait_method_id, generics, false)
}
Expand Down
16 changes: 8 additions & 8 deletions compiler/noirc_frontend/src/hir_def/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,22 +186,22 @@ impl Trait {
(ordered, named)
}

/// Returns a TraitConstraint for this trait using Self as the object
/// type and the uninstantiated generics for any trait generics.
pub fn as_constraint(&self, span: Span) -> TraitConstraint {
pub fn get_trait_generics(&self, span: Span) -> TraitGenerics {
let ordered = vecmap(&self.generics, |generic| generic.clone().as_named_generic());
let named = vecmap(&self.associated_types, |generic| {
let name = Ident::new(generic.name.to_string(), span);
NamedType { name, typ: generic.clone().as_named_generic() }
});
TraitGenerics { ordered, named }
}

/// Returns a TraitConstraint for this trait using Self as the object
/// type and the uninstantiated generics for any trait generics.
pub fn as_constraint(&self, span: Span) -> TraitConstraint {
let trait_generics = self.get_trait_generics(span);
TraitConstraint {
typ: Type::TypeVariable(self.self_type_typevar.clone()),
trait_bound: ResolvedTraitBound {
trait_generics: TraitGenerics { ordered, named },
trait_id: self.id,
span,
},
trait_bound: ResolvedTraitBound { trait_generics, trait_id: self.id, span },
}
}
}
Expand Down
31 changes: 29 additions & 2 deletions compiler/noirc_frontend/src/tests/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1341,9 +1341,7 @@ fn regression_6530() {
assert_eq!(errors.len(), 0);
}

// See https://github.com/noir-lang/noir/issues/7090
#[test]
#[should_panic]
fn calls_trait_method_using_struct_name_when_multiple_impls_exist() {
let src = r#"
trait From2<T> {
Expand All @@ -1367,3 +1365,32 @@ fn calls_trait_method_using_struct_name_when_multiple_impls_exist() {
"#;
assert_no_errors(src);
}

#[test]
fn calls_trait_method_using_struct_name_when_multiple_impls_exist_and_errors_turbofish() {
let src = r#"
trait From2<T> {
fn from2(input: T) -> Self;
}
struct U60Repr {}
impl From2<[Field; 3]> for U60Repr {
fn from2(_: [Field; 3]) -> Self {
U60Repr {}
}
}
impl From2<Field> for U60Repr {
fn from2(_: Field) -> Self {
U60Repr {}
}
}
fn main() {
let _ = U60Repr::<Field>::from2([1, 2, 3]);
}
"#;
let errors = get_program_errors(src);
assert_eq!(errors.len(), 1);
assert!(matches!(
errors[0].0,
CompilationError::TypeError(TypeCheckError::TypeMismatch { .. })
));
}

1 comment on commit 966d8a6

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'Compilation Time'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 1.20.

Benchmark suite Current: 966d8a6 Previous: 521f5ce Ratio
rollup-root 4.366 s 3.61 s 1.21

This comment was automatically generated by workflow using github-action-benchmark.

CC: @TomAFrench

Please sign in to comment.