Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: expr_eq_fn #4934

Merged
merged 4 commits into from
Aug 7, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 48 additions & 21 deletions src/kernel/expr_eq_fn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Author: Leonardo de Moura
*/
#include <vector>
#include <memory>
#include "runtime/alloc.h"
#include "runtime/interrupt.h"
#include "runtime/thread.h"
#include "kernel/expr.h"
Expand All @@ -27,6 +28,8 @@ class expr_eq_fn {
};
typedef std::unordered_set<std::pair<lean_object *, lean_object *>, key_hasher> cache;
cache * m_cache = nullptr;
size_t m_max_stack_depth = 0;
size_t m_counter = 0;
bool check_cache(expr const & a, expr const & b) {
if (!is_shared(a) || !is_shared(b))
return false;
Expand All @@ -38,10 +41,17 @@ class expr_eq_fn {
m_cache->insert(key);
return false;
}
static void check_system() {
::lean::check_system("expression equality test");
void check_system(unsigned depth) {
/*
We used to use `lean::check_system` here. We claim it is ok to not check memory consumption here.
Note that `do_check_interrupted` was set to `false`. Thus, `check_interrupted` and `check_heartbeat` were not being used.
*/
if (depth > m_max_stack_depth) {
if (m_max_stack_depth > 0)
throw stack_space_exception("expression equality test");
}
}
bool apply(expr const & a, expr const & b, bool root = false) {
bool apply(expr const & a, expr const & b, unsigned depth, bool root = false) {
if (is_eqp(a, b)) return true;
if (hash(a) != hash(b)) return false;
if (a.kind() != b.kind()) return false;
Expand All @@ -53,13 +63,18 @@ class expr_eq_fn {
case expr_kind::Sort: return sort_level(a) == sort_level(b);
default: break;
}
if (!root && check_cache(a, b))
if (root) {
m_max_stack_depth = get_available_stack_size() / 256;
} else if (check_cache(a, b)) {
return true;
}
/*
We increase the number of heartbeats here because some code (e.g., `simp`) may spend a lot of time comparing
`Expr`s (e.g., checking a cache with many collisions) without allocating any significant amount of memory.
We use the counter to invoke `add_heartbeats` later. Reason: heartbeat is a thread local storage, and morexpensive to update.
*/
lean_inc_heartbeat();
m_counter++;
depth++;
switch (a.kind()) {
case expr_kind::BVar:
case expr_kind::Lit:
Expand All @@ -69,43 +84,55 @@ class expr_eq_fn {
lean_unreachable(); // LCOV_EXCL_LINE
case expr_kind::MData:
return
apply(mdata_expr(a), mdata_expr(b)) &&
apply(mdata_expr(a), mdata_expr(b), depth) &&
mdata_data(a) == mdata_data(b);
case expr_kind::Proj:
return
apply(proj_expr(a), proj_expr(b)) &&
apply(proj_expr(a), proj_expr(b), depth) &&
proj_sname(a) == proj_sname(b) &&
proj_idx(a) == proj_idx(b);
case expr_kind::Const:
return
const_name(a) == const_name(b) &&
compare(const_levels(a), const_levels(b), [](level const & l1, level const & l2) { return l1 == l2; });
case expr_kind::App:
check_system();
return
apply(app_fn(a), app_fn(b)) &&
apply(app_arg(a), app_arg(b));
case expr_kind::App: {
check_system(depth);
if (!apply(app_arg(a), app_arg(b), depth)) return false;
expr const * curr_a = &app_fn(a);
expr const * curr_b = &app_fn(b);
while (true) {
if (!is_app(*curr_a)) break;
if (!is_app(*curr_b)) return false;
if (!apply(app_arg(*curr_a), app_arg(*curr_b), depth)) return false;
curr_a = &app_fn(*curr_a);
curr_b = &app_fn(*curr_b);
}
return apply(*curr_a, *curr_b, depth);
}
case expr_kind::Lambda: case expr_kind::Pi:
check_system();
check_system(depth);
return
apply(binding_domain(a), binding_domain(b)) &&
apply(binding_body(a), binding_body(b)) &&
apply(binding_domain(a), binding_domain(b), depth) &&
apply(binding_body(a), binding_body(b), depth) &&
(!CompareBinderInfo || binding_name(a) == binding_name(b)) &&
(!CompareBinderInfo || binding_info(a) == binding_info(b));
case expr_kind::Let:
check_system();
check_system(depth);
return
apply(let_type(a), let_type(b)) &&
apply(let_value(a), let_value(b)) &&
apply(let_body(a), let_body(b)) &&
apply(let_type(a), let_type(b), depth) &&
apply(let_value(a), let_value(b), depth) &&
apply(let_body(a), let_body(b), depth) &&
(!CompareBinderInfo || let_name(a) == let_name(b));
}
lean_unreachable(); // LCOV_EXCL_LINE
}
public:
expr_eq_fn() {}
~expr_eq_fn() { if (m_cache) delete m_cache; }
bool operator()(expr const & a, expr const & b) { return apply(a, b, true); }
~expr_eq_fn() {
if (m_cache) delete m_cache;
if (m_counter > 0) add_heartbeats(m_counter);
}
bool operator()(expr const & a, expr const & b) { return apply(a, b, 0, true); }
};

bool is_equal(expr const & a, expr const & b) {
Expand Down
Loading