Skip to content

Commit

Permalink
fix: Fix canonicalization bug (#6033)
Browse files Browse the repository at this point in the history
# Description

## Problem\*

Resolves #6032

## Summary\*

This issue was because when we sort commutative terms we sorted them in
a BTreeSet, but this meant any identical terms were deduplicated. I
changed it to keep track of how many of each term there was so that we
don't accidentally change the expression when canonicalizing it.

## Additional Context



## Documentation\*

Check one:
- [x] No documentation needed.
- [ ] Documentation included in this PR.
- [ ] **[For Experimental Features]** Documentation to be submitted in a
separate PR.

# PR Checklist\*

- [x] I have tested the changes locally.
- [x] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.
  • Loading branch information
jfecher authored Sep 13, 2024
1 parent ab203e4 commit 7397772
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 9 deletions.
22 changes: 15 additions & 7 deletions compiler/noirc_frontend/src/hir_def/types/arithmetic.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::collections::BTreeSet;
use std::collections::BTreeMap;

use crate::{BinaryTypeOperator, Type, TypeBindings, UnificationError};

Expand Down Expand Up @@ -52,7 +52,8 @@ impl Type {
fn sort_commutative(lhs: &Type, op: BinaryTypeOperator, rhs: &Type) -> Type {
let mut queue = vec![lhs.clone(), rhs.clone()];

let mut sorted = BTreeSet::new();
// Maps each term to the number of times that term was used.
let mut sorted = BTreeMap::new();

let zero_value = if op == BinaryTypeOperator::Addition { 0 } else { 1 };
let mut constant = zero_value;
Expand All @@ -68,20 +69,27 @@ impl Type {
if let Some(result) = op.function(constant, new_constant) {
constant = result;
} else {
sorted.insert(Type::Constant(new_constant));
*sorted.entry(Type::Constant(new_constant)).or_default() += 1;
}
}
other => {
sorted.insert(other);
*sorted.entry(other).or_default() += 1;
}
}
}

if let Some(first) = sorted.pop_first() {
let mut typ = first.clone();
let (mut typ, first_type_count) = first.clone();

for rhs in sorted {
typ = Type::InfixExpr(Box::new(typ), op, Box::new(rhs.clone()));
// - 1 since `typ` already is set to the first instance
for _ in 0..first_type_count - 1 {
typ = Type::InfixExpr(Box::new(typ), op, Box::new(first.0.clone()));
}

for (rhs, rhs_count) in sorted {
for _ in 0..rhs_count {
typ = Type::InfixExpr(Box::new(typ), op, Box::new(rhs.clone()));
}
}

if constant != zero_value {
Expand Down
19 changes: 19 additions & 0 deletions compiler/noirc_frontend/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3463,6 +3463,25 @@ fn comptime_type_in_runtime_code() {
));
}

#[test]
fn arithmetic_generics_canonicalization_deduplication_regression() {
let source = r#"
struct ArrData<let N: u32> {
a: [Field; N],
b: [Field; N + N - 1],
}
fn main() {
let _f: ArrData<5> = ArrData {
a: [0; 5],
b: [0; 9],
};
}
"#;
let errors = get_program_errors(source);
assert_eq!(errors.len(), 0);
}

#[test]
fn cannot_mutate_immutable_variable() {
let src = r#"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,11 @@ fn demo_proof<let N: u32>() -> Equiv<W<(N * (N + 1))>, (Equiv<W<N>, (), W<N>, ()
let p1: Equiv<W<(N + 1) * N>, (), W<N * (N + 1)>, ()> = mul_comm();
let p2: Equiv<W<N * (N + 1)>, (), W<N * N + N>, ()> = mul_add::<N, N, 1>();
let p3_sub: Equiv<W<N>, (), W<N>, ()> = mul_one_r();
let p3: Equiv<W<N * N + N>, (), W<N * N + N>, ()> = add_equiv_r::<N * N, N, N, _, _>(p3_sub);
equiv_trans(equiv_trans(p1, p2), p3)
let _p3: Equiv<W<N * N + N>, (), W<N * N + N>, ()> = add_equiv_r::<N * N, N, N, _, _>(p3_sub);
let _p1_to_2 = equiv_trans(p1, p2);

// equiv_trans(p1_to_2, p3)
std::mem::zeroed()
}

fn test_constant_folding<let N: u32>() {
Expand Down

0 comments on commit 7397772

Please sign in to comment.