diff --git a/src/ast/pattern/pattern_inference.cpp b/src/ast/pattern/pattern_inference.cpp index d751a138801..e795188fb31 100644 --- a/src/ast/pattern/pattern_inference.cpp +++ b/src/ast/pattern/pattern_inference.cpp @@ -405,6 +405,44 @@ bool pattern_inference_cfg::pattern_weight_lt::operator()(expr * n1, expr * n2) return num_free_vars1 > num_free_vars2 || (num_free_vars1 == num_free_vars2 && i1.m_size < i2.m_size); } + +app* pattern_inference_cfg::mk_pattern(app* candidate) { + auto has_var_arg = [&](expr* e) { + if (!is_app(e)) + return false; + for (expr* arg : *to_app(e)) + if (is_var(arg)) + return true; + return false; + }; + if (has_var_arg(candidate)) + return m.mk_pattern(candidate); + m_args.reset(); + for (expr* arg : *candidate) { + if (!is_app(arg)) + return m.mk_pattern(candidate); + m_args.push_back(to_app(arg)); + } + for (unsigned i = 0; i < m_args.size(); ++i) { + app* arg = m_args[i]; + if (has_var_arg(arg)) + continue; + m_args[i] = m_args.back(); + --i; + m_args.pop_back(); + + if (is_ground(arg)) + continue; + + for (expr* e : *to_app(arg)) { + if (!is_app(e)) + return m.mk_pattern(candidate); + m_args.push_back(to_app(e)); + } + } + return m.mk_pattern(m_args.size(), (app* const*)m_args.data()); +} + /** \brief Create unary patterns (single expressions that contain all bound variables). If a candidate does not contain all bound @@ -418,7 +456,7 @@ void pattern_inference_cfg::candidates2unary_patterns(ptr_vector const & ca expr2info::obj_map_entry * e = m_candidates_info.find_core(candidate); info const & i = e->get_data().m_value; if (i.m_free_vars.num_elems() == m_num_bindings) { - app * new_pattern = m.mk_pattern(candidate); + app * new_pattern = mk_pattern(candidate); result.push_back(new_pattern); } else { diff --git a/src/ast/pattern/pattern_inference.h b/src/ast/pattern/pattern_inference.h index bb4cf423833..da905dca412 100644 --- a/src/ast/pattern/pattern_inference.h +++ b/src/ast/pattern/pattern_inference.h @@ -188,6 +188,9 @@ class pattern_inference_cfg : public default_rewriter_cfg { ptr_vector m_pre_patterns; expr_pattern_match m_database; + ptr_buffer m_args; + app* mk_pattern(app* candidate); + void candidates2unary_patterns(ptr_vector const & candidate_patterns, ptr_vector & remaining_candidate_patterns, app_ref_buffer & result); diff --git a/src/smt/smt_case_split_queue.cpp b/src/smt/smt_case_split_queue.cpp index 9f5bdb0d433..711e3de3a04 100644 --- a/src/smt/smt_case_split_queue.cpp +++ b/src/smt/smt_case_split_queue.cpp @@ -964,7 +964,7 @@ namespace { } void display(std::ostream & out) override { - if (m_queue.empty() && m_queue2.empty()) + if (m_queue.empty()) return; out << "case-splits:\n"; display_core(out, m_queue, m_head, 1);