From 4b6d7ca0972574ac3989143ca48771d9f28f80a6 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Mon, 25 Jan 2021 17:54:53 -0800 Subject: [PATCH] working on mam --- src/ast/euf/CMakeLists.txt | 1 - src/ast/euf/euf_egraph.cpp | 21 + src/ast/euf/euf_egraph.h | 5 + src/ast/euf/euf_enode.h | 14 + src/ast/euf/euf_etable.h | 2 +- src/ast/for_each_expr.cpp | 2 +- src/ast/for_each_expr.h | 2 +- src/sat/smt/CMakeLists.txt | 2 + src/sat/smt/q_ematch.cpp | 445 ++++++++++++++++++ src/sat/smt/q_ematch.h | 138 ++++++ .../euf/euf_mam.cpp => sat/smt/q_mam.cpp} | 338 ++++++------- src/{ast/euf/euf_mam.h => sat/smt/q_mam.h} | 33 +- src/sat/smt/q_solver.h | 3 + src/sat/smt/sat_th.h | 2 +- src/util/nat_set.h | 8 +- 15 files changed, 807 insertions(+), 209 deletions(-) create mode 100644 src/sat/smt/q_ematch.cpp create mode 100644 src/sat/smt/q_ematch.h rename src/{ast/euf/euf_mam.cpp => sat/smt/q_mam.cpp} (94%) rename src/{ast/euf/euf_mam.h => sat/smt/q_mam.h} (58%) diff --git a/src/ast/euf/CMakeLists.txt b/src/ast/euf/CMakeLists.txt index 5706da8d8..8d3fa2e74 100644 --- a/src/ast/euf/CMakeLists.txt +++ b/src/ast/euf/CMakeLists.txt @@ -3,7 +3,6 @@ z3_add_component(euf euf_enode.cpp euf_etable.cpp euf_egraph.cpp - euf_mam.cpp COMPONENT_DEPENDENCIES ast util diff --git a/src/ast/euf/euf_egraph.cpp b/src/ast/euf/euf_egraph.cpp index f7edfcbfa..537176ecb 100644 --- a/src/ast/euf/euf_egraph.cpp +++ b/src/ast/euf/euf_egraph.cpp @@ -39,6 +39,23 @@ namespace euf { return n; } + enode* egraph::find(expr* e, unsigned n, enode* const* args) { + if (m_tmp_node && m_tmp_node_capacity < n) { + memory::deallocate(m_tmp_node); + m_tmp_node = nullptr; + } + if (!m_tmp_node) { + m_tmp_node = enode::mk_tmp(n); + m_tmp_node_capacity = n; + } + for (unsigned i = 0; i < n; ++i) + m_tmp_node->m_args[i] = args[i]; + m_tmp_node->m_num_args = n; + m_tmp_node->m_expr = e; + return m_table.find(m_tmp_node); + } + + enode_vector const& egraph::enodes_of(func_decl* f) { unsigned id = f->get_decl_id(); if (id < m_decl2enodes.size()) @@ -115,6 +132,8 @@ namespace euf { egraph::~egraph() { for (enode* n : m_nodes) n->m_parents.finalize(); + if (m_tmp_node) + memory::deallocate(m_tmp_node); } void egraph::add_th_eq(theory_id id, theory_var v1, theory_var v2, enode* c, enode* r) { @@ -382,6 +401,8 @@ namespace euf { r2->inc_class_size(r1->class_size()); merge_th_eq(r1, r2); reinsert_parents(r1, r2); + if (m_on_merge) + m_on_merge(r2, r1); } void egraph::remove_parents(enode* r1, enode* r2) { diff --git a/src/ast/euf/euf_egraph.h b/src/ast/euf/euf_egraph.h index ee7d14a84..2feb9e79f 100644 --- a/src/ast/euf/euf_egraph.h +++ b/src/ast/euf/euf_egraph.h @@ -148,6 +148,8 @@ namespace euf { unsigned_vector m_scopes; enode_vector m_expr2enode; enode* m_tmp_eq { nullptr }; + enode* m_tmp_node { nullptr }; + unsigned m_tmp_node_capacity { 0 }; enode_vector m_nodes; expr_ref_vector m_exprs; vector m_decl2enodes; @@ -165,6 +167,7 @@ namespace euf { enode_vector m_todo; stats m_stats; bool m_uses_congruence { false }; + std::function m_on_merge; std::function m_used_eq; std::function m_used_cc; std::function m_display_justification; @@ -218,6 +221,7 @@ namespace euf { egraph(ast_manager& m); ~egraph(); enode* find(expr* f) const { return m_expr2enode.get(f->get_id(), nullptr); } + enode* find(expr* f, unsigned n, enode* const* args); enode* mk(expr* f, unsigned generation, unsigned n, enode *const* args); enode_vector const& enodes_of(func_decl* f); void push() { ++m_num_scopes; } @@ -269,6 +273,7 @@ namespace euf { void set_value(enode* n, lbool value); void set_bool_var(enode* n, unsigned v) { n->set_bool_var(v); } + void set_on_merge(std::function& on_merge) { m_on_merge = on_merge; } void set_used_eq(std::function& used_eq) { m_used_eq = used_eq; } void set_used_cc(std::function& used_cc) { m_used_cc = used_cc; } void set_display_justification(std::function & d) { m_display_justification = d; } diff --git a/src/ast/euf/euf_enode.h b/src/ast/euf/euf_enode.h index 447e6bcaf..17136d31d 100644 --- a/src/ast/euf/euf_enode.h +++ b/src/ast/euf/euf_enode.h @@ -109,6 +109,20 @@ namespace euf { n->m_args[i] = nullptr; return n; } + + static enode* mk_tmp(unsigned num_args) { + void* mem = memory::allocate(get_enode_size(num_args)); + enode* n = new (mem) enode(); + n->m_expr = nullptr; + n->m_next = n; + n->m_root = n; + n->m_commutative = true; + n->m_num_args = 2; + n->m_merge_enabled = true; + for (unsigned i = 0; i < num_args; ++i) + n->m_args[i] = nullptr; + return n; + } void set_update_children() { m_update_children = true; } diff --git a/src/ast/euf/euf_etable.h b/src/ast/euf/euf_etable.h index 6e88cf9b9..6274af142 100644 --- a/src/ast/euf/euf_etable.h +++ b/src/ast/euf/euf_etable.h @@ -83,7 +83,7 @@ namespace euf { struct cg_comm_eq { bool & m_commutativity; - cg_comm_eq( bool & c): m_commutativity(c) {} + cg_comm_eq(bool & c): m_commutativity(c) {} bool operator()(enode * n1, enode * n2) const { SASSERT(n1->num_args() == 2); SASSERT(n2->num_args() == 2); diff --git a/src/ast/for_each_expr.cpp b/src/ast/for_each_expr.cpp index 75d62161a..480a66d82 100644 --- a/src/ast/for_each_expr.cpp +++ b/src/ast/for_each_expr.cpp @@ -111,7 +111,7 @@ bool subterms::iterator::operator!=(iterator const& other) const { subterms_postorder::subterms_postorder(expr_ref_vector const& es): m_es(es) {} -subterms_postorder::subterms_postorder(expr_ref& e) : m_es(e.m()) { m_es.push_back(e); } +subterms_postorder::subterms_postorder(expr_ref const& e) : m_es(e.m()) { m_es.push_back(e); } subterms_postorder::iterator subterms_postorder::begin() { return iterator(*this, true); } subterms_postorder::iterator subterms_postorder::end() { return iterator(*this, false); } subterms_postorder::iterator::iterator(subterms_postorder& f, bool start): m_es(f.m_es) { diff --git a/src/ast/for_each_expr.h b/src/ast/for_each_expr.h index 5c4f95403..d58dcccfd 100644 --- a/src/ast/for_each_expr.h +++ b/src/ast/for_each_expr.h @@ -203,7 +203,7 @@ public: bool operator!=(iterator const& other) const; }; subterms_postorder(expr_ref_vector const& es); - subterms_postorder(expr_ref& e); + subterms_postorder(expr_ref const& e); iterator begin(); iterator end(); }; diff --git a/src/sat/smt/CMakeLists.txt b/src/sat/smt/CMakeLists.txt index 6eb4f429d..c81a75283 100644 --- a/src/sat/smt/CMakeLists.txt +++ b/src/sat/smt/CMakeLists.txt @@ -29,6 +29,8 @@ z3_add_component(sat_smt euf_relevancy.cpp euf_solver.cpp fpa_solver.cpp + q_ematch.cpp + q_mam.cpp q_mbi.cpp q_model_fixer.cpp q_solver.cpp diff --git a/src/sat/smt/q_ematch.cpp b/src/sat/smt/q_ematch.cpp new file mode 100644 index 000000000..66a4c3fd4 --- /dev/null +++ b/src/sat/smt/q_ematch.cpp @@ -0,0 +1,445 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + q_ematch.cpp + +Abstract: + + E-matching quantifier instantiation plugin + +Author: + + Nikolaj Bjorner (nbjorner) 2021-01-24 + +Todo: + +- clausify +- propagate without instantiations, produce explanations for eval +- generations +- insert instantiations into priority queue +- cache instantiations and substitutions +- nested quantifiers +- non-cnf quantifiers + +--*/ + +#include "ast/rewriter/var_subst.h" +#include "solver/solver.h" +#include "sat/smt/sat_th.h" +#include "sat/smt/euf_solver.h" +#include "sat/smt/q_solver.h" +#include "sat/smt/q_mam.h" +#include "sat/smt/q_ematch.h" + + +namespace q { + + ematch::ematch(euf::solver& ctx, solver& s): + ctx(ctx), + m_qs(s), + m(ctx.get_manager()) + { + std::function _on_merge = + [&](euf::enode* root, euf::enode* other) { + on_merge(root, other); + }; + ctx.get_egraph().set_on_merge(_on_merge); + m_mam = mam::mk(ctx, *this); + } + + void ematch::ensure_ground_enodes(expr* e) { + mam::ground_subterms(e, m_ground); + for (expr* g : m_ground) + m_qs.e_internalize(g); + } + + void ematch::ensure_ground_enodes(clause const& c) { + quantifier* q = c.m_q; + unsigned num_patterns = q->get_num_patterns(); + for (unsigned i = 0; i < num_patterns; i++) + ensure_ground_enodes(q->get_pattern(i)); + for (auto lit : c.m_lits) { + ensure_ground_enodes(lit.lhs); + ensure_ground_enodes(lit.rhs); + } + } + + struct restore_watch : public trail { + vector& v; + unsigned idx, offset; + restore_watch(vector& v, unsigned idx): + v(v), idx(idx), offset(v[idx].size()) {} + void undo(euf::solver& ctx) override { + v[idx].shrink(offset); + } + }; + + void ematch::on_merge(euf::enode* root, euf::enode* other) { + SASSERT(root->get_root() == other->get_root()); + unsigned root_id = root->get_expr_id(); + unsigned other_id = other->get_expr_id(); + m_watch.reserve(std::max(root_id, other_id) + 1); + + insert_to_propagate(root_id); + insert_to_propagate(other_id); + + if (!m_watch[other_id].empty()) { + ctx.push(restore_watch(m_watch, root_id)); + m_watch[root_id].append(m_watch[other_id]); + } + + m_mam->on_merge(root, other); + if (m_lazy_mam) + m_lazy_mam->on_merge(root, other); + } + + // watch only nodes introduced in bindings or ground arguments of functions + // or functions that have been inserted. + + void ematch::add_watch(euf::enode* n, unsigned idx) { + unsigned root_id = n->get_root_id(); + m_watch.reserve(root_id + 1); + ctx.push(restore_watch(m_watch, root_id)); + m_watch[root_id].push_back(idx); + } + + void ematch::init_watch(expr* e, unsigned clause_idx) { + ptr_buffer todo; + m_mark.reset(); + todo.push_back(e); + while (!todo.empty()) { + expr* t = todo.back(); + if (m_mark.is_marked(t)) + continue; + todo.pop_back(); + m_mark.mark(t); + if (is_ground(t)) { + add_watch(ctx.get_egraph().find(t), clause_idx); + continue; + } + if (!is_app(t)) + continue; + for (expr* arg : *to_app(t)) + todo.push_back(arg); + } + } + + void ematch::init_watch(clause& c, unsigned idx) { + for (auto lit : c.m_lits) { + if (!is_ground(lit.lhs)) + init_watch(lit.lhs, idx); + if (!is_ground(lit.rhs)) + init_watch(lit.rhs, idx); + } + } + + ematch::binding* ematch::alloc_binding(unsigned n) { + unsigned sz = sizeof(binding) + sizeof(euf::enode* const*)*n; + void* mem = ctx.get_region().allocate(sz); + return new (mem) binding(); + } + + void ematch::on_binding(quantifier* q, app* pat, euf::enode* const* _binding) { + clause& c = *m_clauses[m_q2clauses[q]]; + if (propagate(_binding, c)) + return; + unsigned n = q->get_num_decls(); + binding* b = alloc_binding(n); + b->m_propagated = false; + for (unsigned i = 0; i < n; ++i) + b->m_nodes[i] = _binding[i]; + c.m_bindings.push_back(b); + ctx.push(push_back_vector>(c.m_bindings)); + } + + bool ematch::propagate(euf::enode* const* binding, clause& c) { + unsigned clause_idx = m_q2clauses[c.m_q]; + struct scoped_reset { + ematch& e; + scoped_reset(ematch& e): e(e) { e.m_mark.reset(); } + ~scoped_reset() { e.m_mark.reset(); } + }; + scoped_reset _sr(*this); + + unsigned idx = UINT_MAX; + for (unsigned i = c.m_lits.size(); i-- > 0; ) { + lit l = c.m_lits[i]; + m_indirect_nodes.reset(); + lbool cmp = compare(binding, l.lhs, l.rhs); + switch (cmp) { + case l_false: + if (l.sign) { + if (i > 0) + std::swap(c.m_lits[0], c.m_lits[i]); + return true; + } + break; + case l_true: + if (!l.sign) { + if (i > 0) + std::swap(c.m_lits[0], c.m_lits[i]); + return true; + } + break; + case l_undef: + if (idx == 0) { + // attach bindings and indirect nodes + // to watch + for (euf::enode* n : m_indirect_nodes) + add_watch(n, clause_idx); + for (unsigned j = c.m_q->get_num_decls(); j-- > 0; ) + add_watch(binding[j], clause_idx); + if (i > 1) + std::swap(c.m_lits[1], c.m_lits[i]); + return false; + } + else if (i > 0) { + std::swap(c.m_lits[0], c.m_lits[i]); + idx = 0; + } + break; + } + } + if (idx == UINT_MAX) { + std::cout << "clause is false\n"; + } + else { + std::cout << "unit propagate\n"; + } + instantiate(binding, c); + return true; + } + + // vanilla instantiation method. + void ematch::instantiate(euf::enode* const* binding, clause& c) { + expr_ref_vector _binding(m); + quantifier* q = c.m_q; + for (unsigned i = 0; i < q->get_num_decls(); ++i) + _binding.push_back(binding[i]->get_expr()); + var_subst subst(m); + expr_ref result = subst(q->get_expr(), _binding); + if (is_forall(q)) + m_qs.add_clause(~ctx.mk_literal(q), ctx.mk_literal(result)); + else + m_qs.add_clause(ctx.mk_literal(q), ~ctx.mk_literal(result)); + } + + lbool ematch::compare(euf::enode* const* binding, expr* s, expr* t) { + euf::enode* sn = eval(binding, s); + euf::enode* tn = eval(binding, t); + lbool c; + if (sn && sn == tn) + return l_true; + if (sn && tn && ctx.get_egraph().are_diseq(sn, tn)) + return l_false; + if (sn && tn) + return l_undef; + if (!sn && !tn) + return compare_rec(binding, s, t); + if (!sn && tn) + for (euf::enode* t1 : euf::enode_class(tn)) + if (c = compare_rec(binding, s, t1->get_expr()), c != l_undef) + return c; + if (sn && !tn) + for (euf::enode* s1 : euf::enode_class(sn)) + if (c = compare_rec(binding, t, s1->get_expr()), c != l_undef) + return c; + return l_undef; + } + + // f(p1) = f(p2) if p1 = p2 + // f(p1) != f(p2) if p1 != p2 and f is injective + lbool ematch::compare_rec(euf::enode* const* binding, expr* s, expr* t) { + if (m.are_equal(s, t)) + return l_true; + if (m.are_distinct(s, t)) + return l_false; + if (!is_app(s) || !is_app(t)) + return l_undef; + if (to_app(s)->get_decl() != to_app(t)->get_decl()) + return l_undef; + if (to_app(s)->get_num_args() != to_app(t)->get_num_args()) + return l_undef; + bool is_injective = to_app(s)->get_decl()->is_injective(); + bool has_undef = false; + for (unsigned i = to_app(s)->get_num_args(); i-- > 0; ) { + switch (compare(binding, to_app(s)->get_arg(i), to_app(t)->get_arg(i))) { + case l_true: + break; + case l_false: + if (is_injective) + return l_false; + return l_undef; + case l_undef: + if (!is_injective) + return l_undef; + has_undef = true; + break; + } + } + return has_undef ? l_undef : l_true; + } + + euf::enode* ematch::eval(euf::enode* const* binding, expr* e) { + if (is_ground(e)) + ctx.get_egraph().find(e)->get_root(); + if (m_mark.is_marked(e)) + return m_eval[e->get_id()]; + ptr_buffer todo; + ptr_buffer args; + todo.push_back(e); + while (!todo.empty()) { + expr* t = todo.back(); + SASSERT(!is_ground(t) || ctx.get_egraph().find(t)); + if (is_ground(t)) { + m_eval.setx(t->get_id(), ctx.get_egraph().find(t), nullptr); + SASSERT(m_eval[t->get_id()]); + todo.pop_back(); + continue; + } + if (m_mark.is_marked(t)) { + todo.pop_back(); + continue; + } + if (is_var(t)) { + m_mark.mark(t); + m_eval.setx(t->get_id(), binding[to_var(t)->get_idx()], nullptr); + todo.pop_back(); + continue; + } + if (!is_app(t)) + return nullptr; + args.reset(); + for (expr* arg : *to_app(t)) { + if (m_mark.is_marked(arg)) + args.push_back(m_eval[t->get_id()]); + else + todo.push_back(arg); + } + if (args.size() == to_app(t)->get_num_args()) { + euf::enode* n = ctx.get_egraph().find(t, args.size(), args.c_ptr()); + if (!n) + return nullptr; + m_indirect_nodes.push_back(n); + m_eval.setx(t->get_id(), n->get_root(), nullptr); + m_mark.mark(t); + todo.pop_back(); + } + } + return m_eval[e->get_id()]->get_root(); + } + + void ematch::insert_to_propagate(unsigned node_id) { + if (m_node_in_queue.contains(node_id)) + return; + m_node_in_queue.insert(node_id); + for (unsigned idx : m_watch[node_id]) { + if (!m_clause_in_queue.contains(idx)) { + m_clause_in_queue.insert(idx); + m_queue.push_back(idx); + } + } + } + + bool ematch::propagate() { + m_mam->propagate(); + if (m_qhead >= m_queue.size()) + return false; + bool propagated = false; + ctx.push(value_trail(m_qhead)); + for (; m_qhead < m_queue.size(); ++m_qhead) { + unsigned idx = m_queue[m_qhead]; + clause& c = *m_clauses[idx]; + for (auto& b : c.bindings()) { + if (!b->propagated() && propagate(b->m_nodes, c)) { + ctx.push(value_trail(b->m_propagated)); + b->set_propagated(true); + propagated = true; + } + } + } + m_clause_in_queue.reset(); + m_node_in_queue.reset(); + return propagated; + } + + ematch::clause* ematch::clausify(quantifier* q) { + NOT_IMPLEMENTED_YET(); + return nullptr; + } + + /** + * Attach ground subterms of patterns so they appear shared. + */ + void ematch::attach_ground_pattern_terms(expr* pat) { + mam::ground_subterms(pat, m_ground); + for (expr* g : m_ground) { + euf::enode* n = ctx.get_egraph().find(g); + if (!n->is_attached_to(m_qs.get_id())) { + euf::theory_var v = m_qs.mk_var(n); + ctx.get_egraph().add_th_var(n, v, m_qs.get_id()); + } + } + } + + struct ematch::pop_clause : public trail { + ematch& em; + pop_clause(ematch& em): em(em) {} + void undo(euf::solver& ctx) override { + em.m_q2clauses.remove(em.m_clauses.back()->m_q); + em.m_clauses.pop_back(); + } + }; + + void ematch::add(quantifier* q) { + clause* c = clausify(q); + ensure_ground_enodes(*c); + unsigned idx = m_clauses.size(); + m_clauses.push_back(c); + m_q2clauses.insert(q, idx); + ctx.push(pop_clause(*this)); + init_watch(*c, idx); + + bool has_unary_pattern = false; + unsigned num_patterns = q->get_num_patterns(); + for (unsigned i = 0; i < num_patterns && !has_unary_pattern; i++) + has_unary_pattern = (1 == to_app(q->get_pattern(i))->get_num_args()); + unsigned num_eager_multi_patterns = ctx.get_config().m_qi_max_eager_multipatterns; + if (!has_unary_pattern) + num_eager_multi_patterns++; + for (unsigned i = 0, j = 0; i < num_patterns; i++) { + app * mp = to_app(q->get_pattern(i)); + SASSERT(m.is_pattern(mp)); + bool unary = (mp->get_num_args() == 1); + TRACE("quantifier", tout << "adding:\n" << expr_ref(mp, m) << "\n";); + if (!unary && j >= num_eager_multi_patterns) { + TRACE("quantifier", tout << "delaying (too many multipatterns):\n" << mk_ismt2_pp(mp, m) << "\n";); + if (!m_lazy_mam) + m_lazy_mam = mam::mk(ctx, *this); + m_lazy_mam->add_pattern(q, mp); + } + else + m_mam->add_pattern(q, mp); + + attach_ground_pattern_terms(mp); + + if (!unary) + j++; + } + } + + bool ematch::operator()() { + if (m_lazy_mam) + m_lazy_mam->propagate(); + if (propagate()) + return true; + // + // TODO: loop over pending bindings and instantiate them + // + NOT_IMPLEMENTED_YET(); + return false; + } + +} diff --git a/src/sat/smt/q_ematch.h b/src/sat/smt/q_ematch.h new file mode 100644 index 000000000..a4505e3a7 --- /dev/null +++ b/src/sat/smt/q_ematch.h @@ -0,0 +1,138 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + q_ematch.h + +Abstract: + + E-matching quantifier instantiation plugin + +Author: + + Nikolaj Bjorner (nbjorner) 2021-01-24 + +--*/ +#pragma once + +#include "util/nat_set.h" +#include "solver/solver.h" +#include "sat/smt/sat_th.h" +#include "sat/smt/q_mam.h" + +namespace euf { + class solver; +} + +namespace q { + + class solver; + + class ematch { + struct stats { + unsigned m_num_instantiations; + + stats() { reset(); } + + void reset() { + memset(this, 0, sizeof(*this)); + } + }; + + + struct lit { + expr_ref lhs; + expr_ref rhs; + bool sign; + lit(expr_ref& lhs, expr_ref& rhs, bool sign): + lhs(lhs), rhs(rhs), sign(sign) {} + + }; + + struct binding { + bool m_propagated { false }; + euf::enode* m_nodes[0]; + + binding() {} + + bool propagated() const { return m_propagated; } + void set_propagated(bool b) { m_propagated = b; } + euf::enode* const* nodes() { return m_nodes; } + }; + + binding* alloc_binding(unsigned n); + + struct clause { + vector m_lits; + quantifier* m_q; + ptr_vector m_bindings; + + ptr_vector const& bindings() { return m_bindings; } + }; + + struct pop_clause; + + euf::solver& ctx; + solver& m_qs; + ast_manager& m; + scoped_ptr m_mam, m_lazy_mam; + ptr_vector m_clauses; + obj_map m_q2clauses; + vector m_watch; // expr_id -> clause-index* + stats m_stats; + expr_fast_mark1 m_mark; + + nat_set m_node_in_queue; + nat_set m_clause_in_queue; + unsigned m_qhead { 0 }; + unsigned_vector m_queue; + + ptr_vector m_ground; + void ensure_ground_enodes(expr* e); + void ensure_ground_enodes(clause const& c); + + // compare s, t modulo sign under binding + lbool compare(euf::enode* const* binding, expr* s, expr* t); + lbool compare_rec(euf::enode* const* binding, expr* s, expr* t); + euf::enode_vector m_eval, m_indirect_nodes; + euf::enode* eval(euf::enode* const* binding, expr* e); + + bool propagate(euf::enode* const* binding, clause& c); + void instantiate(euf::enode* const* binding, clause& c); + + // register as callback into egraph. + void on_merge(euf::enode* root, euf::enode* other); + void insert_to_propagate(unsigned node_id); + + void add_watch(euf::enode* root, unsigned clause_idx); + void init_watch(expr* e, unsigned clause_idx); + void init_watch(clause& c, unsigned idx); + + // extract explanation + void get_antecedents(euf::enode* const* binding, unsigned clause_idx, bool probing); + + void attach_ground_pattern_terms(expr* pat); + clause* clausify(quantifier* q); + + + public: + + ematch(euf::solver& ctx, solver& s); + + bool operator()(); + + bool propagate(); + + void init_search(); + + void add(quantifier* q); + + void collect_statistics(statistics& st) const; + + // callback from mam + void on_binding(quantifier* q, app* pat, euf::enode* const* binding); + + }; + +} diff --git a/src/ast/euf/euf_mam.cpp b/src/sat/smt/q_mam.cpp similarity index 94% rename from src/ast/euf/euf_mam.cpp rename to src/sat/smt/q_mam.cpp index 2bde17855..813100cc5 100644 --- a/src/ast/euf/euf_mam.cpp +++ b/src/sat/smt/q_mam.cpp @@ -28,9 +28,11 @@ Revision History: #include "ast/ast_pp.h" #include "ast/ast_ll_pp.h" #include "ast/ast_smt2_pp.h" -#include "ast/euf/euf_mam.h" #include "ast/euf/euf_enode.h" #include "ast/euf/euf_egraph.h" +#include "sat/smt/q_mam.h" +#include "sat/smt/q_ematch.h" +#include "sat/smt/euf_solver.h" @@ -60,7 +62,7 @@ Revision History: #define IS_CGR_SUPPORT true -namespace euf { +namespace q { // ------------------------------------ // // Trail @@ -71,12 +73,11 @@ namespace euf { typedef trail_stack mam_trail_stack; - typedef trail mam_trail; template - class mam_value_trail : public value_trail { + class mam_value_trail : public value_trail { public: - mam_value_trail(T & value):value_trail(value) {} + mam_value_trail(T & value):value_trail(value) {} }; unsigned get_max_generation(unsigned n, enode* const* nodes) { @@ -458,7 +459,7 @@ namespace euf { } void display_label_hashes_core(std::ostream & out, app * p) const { - if (p->is_ground()) { + if (is_ground(p)) { enode * e = get_enode(*m_egraph, p); SASSERT(e->has_lbl_hash()); out << "#" << e->get_expr_id() << ":" << e->get_lbl_hash() << " "; @@ -606,8 +607,8 @@ namespace euf { // ------------------------------------ class code_tree_manager { + euf::solver & ctx; label_hasher & m_lbl_hasher; - mam_trail_stack & m_trail_stack; region & m_region; template @@ -640,10 +641,10 @@ namespace euf { } public: - code_tree_manager(label_hasher & h, mam_trail_stack & s): + code_tree_manager(label_hasher & h, euf::solver& ctx): + ctx(ctx), m_lbl_hasher(h), - m_trail_stack(s), - m_region(s.get_region()) { + m_region(ctx.get_region()) { } code_tree * mk_code_tree(func_decl * lbl, unsigned short num_args, bool filter_candidates) { @@ -765,20 +766,20 @@ namespace euf { } void set_next(instruction * instr, instruction * new_next) { - m_trail_stack.push(mam_value_trail(instr->m_next)); + ctx.push(mam_value_trail(instr->m_next)); instr->m_next = new_next; } void save_num_regs(code_tree * tree) { - m_trail_stack.push(mam_value_trail(tree->m_num_regs)); + ctx.push(mam_value_trail(tree->m_num_regs)); } void save_num_choices(code_tree * tree) { - m_trail_stack.push(mam_value_trail(tree->m_num_choices)); + ctx.push(mam_value_trail(tree->m_num_choices)); } void insert_new_lbl_hash(filter * instr, unsigned h) { - m_trail_stack.push(mam_value_trail(instr->m_lbl_set)); + ctx.push(mam_value_trail(instr->m_lbl_set)); instr->m_lbl_set.insert(h); } }; @@ -947,7 +948,7 @@ namespace euf { if (to_app(p)->is_ground()) { // ground applications are viewed as constants, and eagerly // converted into enodes. - enode * e = mk_enode(m_egraph, m_qa, to_app(p)); + enode * e = m_egraph.find(p); m_seq.push_back(m_ct_manager.mk_check(reg, e)); set_check_mark(reg, NOT_CHECKED); // reset mark, register was fully processed. continue; @@ -1071,7 +1072,7 @@ namespace euf { if (is_ground(n)) { unsigned oreg = m_tree->m_num_regs; m_tree->m_num_regs += 1; - enode * e = mk_enode(m_egraph, m_qa, n); + enode * e = m_egraph.find(n); m_seq.push_back(m_ct_manager.mk_get_enode(oreg, e)); return oreg; } @@ -1168,8 +1169,8 @@ namespace euf { SASSERT(is_app(curr)); - if (to_app(curr)->is_ground()) { - enode * e = mk_enode(m_egraph, m_qa, to_app(curr)); + if (is_ground(curr)) { + enode * e = m_egraph.find(curr); joints.push_back(TAG(enode *, e, GROUND_TERM_TAG)); continue; } @@ -1310,20 +1311,20 @@ namespace euf { unsigned reg1 = instr->m_reg1; unsigned reg2 = instr->m_reg2; return - m_registers[reg1] != 0 && + m_registers[reg1] != nullptr && m_registers[reg1] == m_registers[reg2]; } bool is_compatible(check * instr) const { unsigned reg = instr->m_reg; enode * n = instr->m_enode; - if (m_registers[reg] == 0) + if (!m_registers[reg]) return false; if (!is_app(m_registers[reg])) return false; if (!to_app(m_registers[reg])->is_ground()) return false; - enode * n_prime = mk_enode(m_egraph, m_qa, to_app(m_registers[reg])); + enode * n_prime = m_egraph.find(m_registers[reg]); // it is safe to compare the roots because the modifications // on the code tree are chronological. return n->get_root() == n_prime->get_root(); @@ -1341,7 +1342,7 @@ namespace euf { SASSERT(is_app(m_registers[reg])); app * p = to_app(m_registers[reg]); if (p->is_ground()) { - enode * e = mk_enode(m_egraph, m_qa, p); + enode * e = m_egraph.find(p); if (!e->has_lbl_hash()) e->set_lbl_hash(m_egraph); return e->get_lbl_hash(); @@ -1887,7 +1888,7 @@ namespace euf { // We have to provide the number of expected arguments because we have flat-assoc applications such as +. // Flat-assoc applications may have arbitrary number of arguments. enode * get_first_f_app(func_decl * lbl, unsigned num_expected_args, enode * first) { - for (enode* curr : enode_class(first)) { + for (enode* curr : euf::enode_class(first)) { if (curr->get_decl() == lbl && curr->is_cgr() && curr->num_args() == num_expected_args) { update_max_generation(curr, first); return curr; @@ -1919,7 +1920,7 @@ namespace euf { switch (num_args) { case 1: m_args[0] = m_registers[pc->m_iregs[0]]->get_root(); - for (enode* n : enode_class(r)) { + for (enode* n : euf::enode_class(r)) { if (n->get_decl() == f && n->get_arg(0)->get_root() == m_args[0]) { update_max_generation(n, r); @@ -1930,7 +1931,7 @@ namespace euf { case 2: m_args[0] = m_registers[pc->m_iregs[0]]->get_root(); m_args[1] = m_registers[pc->m_iregs[1]]->get_root(); - for (enode* n : enode_class(r)) { + for (enode* n : euf::enode_class(r)) { if (n->get_decl() == f && n->get_arg(0)->get_root() == m_args[0] && n->get_arg(1)->get_root() == m_args[1]) { @@ -1943,7 +1944,7 @@ namespace euf { m_args.reserve(num_args+1, 0); for (unsigned i = 0; i < num_args; i++) m_args[i] = m_registers[pc->m_iregs[i]]->get_root(); - for (enode* n : enode_class(r)) { + for (enode* n : euf::enode_class(r)) { if (n->get_decl() == f) { unsigned i = 0; for (; i < num_args; i++) { @@ -2055,17 +2056,13 @@ namespace euf { enode_vector * interpreter::mk_depth1_vector(enode * n, func_decl * f, unsigned i) { enode_vector * v = mk_enode_vector(); n = n->get_root(); - enode_vector::const_iterator it = n->begin_parents(); - enode_vector::const_iterator end = n->end_parents(); - for (; it != end; ++it) { - enode * p = *it; + for (enode* p : euf::enode_parents(n)) { if (p->get_decl() == f && i < p->num_args() && m_egraph.is_relevant(p) && p->is_cgr() && - p->get_arg(i)->get_root() == n) { + p->get_arg(i)->get_root() == n) v->push_back(p); - } } return v; } @@ -2081,10 +2078,7 @@ namespace euf { return nullptr; unsigned num_args = n->num_args(); enode_vector * v = mk_enode_vector(); - enode_vector::const_iterator it1 = n->begin_parents(); - enode_vector::const_iterator end1 = n->end_parents(); - for (; it1 != end1; ++it1) { - enode * p = *it1; + for (enode* p : euf::enode_parents(n)) { if (p->get_decl() == j2->m_decl && m_egraph.is_relevant(p) && p->num_args() > j2->m_arg_pos && @@ -2092,10 +2086,7 @@ namespace euf { p->get_arg(j2->m_arg_pos)->get_root() == n) { // p is in joint2 p = p->get_root(); - enode_vector::const_iterator it2 = p->begin_parents(); - enode_vector::const_iterator end2 = p->end_parents(); - for (; it2 != end2; ++it2) { - enode * p2 = *it2; + for (enode* p2 : euf::enode_parents(p)) { if (p2->get_decl() == f && num_args == n->num_args() && num_args == p2->num_args() && @@ -2816,27 +2807,27 @@ namespace euf { ast_manager & m; compiler & m_compiler; ptr_vector m_trees; // mapping: func_label -> tree - mam_trail_stack & m_trail_stack; + euf::solver& ctx; #ifdef Z3DEBUG egraph * m_egraph; #endif - class mk_tree_trail : public mam_trail { + class mk_tree_trail : public trail { ptr_vector & m_trees; unsigned m_lbl_id; public: mk_tree_trail(ptr_vector & t, unsigned id):m_trees(t), m_lbl_id(id) {} - void undo(mam_impl & m) override { + void undo(euf::solver & m) override { dealloc(m_trees[m_lbl_id]); m_trees[m_lbl_id] = nullptr; } }; public: - code_tree_map(ast_manager & m, compiler & c, mam_trail_stack & s): + code_tree_map(ast_manager & m, compiler & c, euf::solver& ctx): m(m), m_compiler(c), - m_trail_stack(s) { + ctx(ctx) { } #ifdef Z3DEBUG @@ -2867,7 +2858,7 @@ namespace euf { m_trees[lbl_id] = m_compiler.mk_tree(qa, mp, first_idx, false); SASSERT(m_trees[lbl_id]->expected_num_args() == p->get_num_args()); DEBUG_CODE(m_trees[lbl_id]->set_egraph(m_egraph);); - m_trail_stack.push(mk_tree_trail(m_trees, lbl_id)); + ctx.push(mk_tree_trail(m_trees, lbl_id)); } else { code_tree * tree = m_trees[lbl_id]; @@ -2880,7 +2871,7 @@ namespace euf { } } DEBUG_CODE(m_trees[lbl_id]->get_patterns().push_back(mp); - m_trail_stack.push(push_back_trail(m_trees[lbl_id]->get_patterns()));); + ctx.push(push_back_trail(m_trees[lbl_id]->get_patterns()));); TRACE("trigger_bug", tout << "after add_pattern, first_idx: " << first_idx << "\n"; m_trees[lbl_id]->display(tout);); } @@ -3033,11 +3024,11 @@ namespace euf { // // ------------------------------------ class mam_impl : public mam { + euf::solver& ctx; egraph & m_egraph; - std::function m_add_instance; + ematch & m_ematch; ast_manager & m; bool m_use_filters; - mam_trail_stack m_trail_stack; label_hasher m_lbl_hasher; code_tree_manager m_ct_manager; compiler m_compiler; @@ -3072,23 +3063,31 @@ namespace euf { // temporary field used to collect candidates ptr_vector m_todo; - obj_hashtable m_shared_enodes; // ground terms that appear in patterns. + enode * m_root { nullptr }; // temp field + enode * m_other { nullptr }; // temp field + bool m_check_missing_instances { false }; - enode * m_r1; // temp field - enode * m_r2; // temp field - - class add_shared_enode_trail; - friend class add_shared_enode_trail; - - class add_shared_enode_trail : public mam_trail { - enode * m_enode; + class reset_to_match : public trail { + mam_impl& i; public: - add_shared_enode_trail(enode * n):m_enode(n) {} - void undo(mam_impl & m) override { m.m_shared_enodes.erase(m_enode); } + reset_to_match(mam_impl& i):i(i) {} + void undo(euf::solver& ctx) override { + if (i.m_to_match.empty()) + return; + for (code_tree* t : i.m_to_match) + t->reset_candidates(); + i.m_to_match.reset(); + } + }; + + class reset_new_patterns : public trail { + mam_impl& i; + public: + reset_new_patterns(mam_impl& i):i(i) {} + void undo(euf::solver& ctx) override { + i.m_new_patterns.reset(); + } }; - - - bool m_check_missing_instances{ false }; enode_vector * mk_tmp_vector() { enode_vector * r = m_pool.mk(); @@ -3103,8 +3102,10 @@ namespace euf { void add_candidate(code_tree * t, enode * app) { if (t != nullptr) { TRACE("mam_candidate", tout << "adding candidate:\n" << mk_ll_pp(app->get_expr(), m);); - if (!t->has_candidates()) + if (!t->has_candidates()) { m_to_match.push_back(t); + ctx.push(reset_to_match(*this)); + } t->add_candidate(app); } } @@ -3126,7 +3127,7 @@ namespace euf { void update_lbls(enode * n, unsigned elem) { approx_set & r_lbls = n->get_root()->get_lbls(); if (!r_lbls.may_contain(elem)) { - m_trail_stack.push(mam_value_trail(r_lbls)); + ctx.push(mam_value_trail(r_lbls)); r_lbls.insert(elem); } } @@ -3138,7 +3139,7 @@ namespace euf { TRACE("mam_bug", tout << "update_clbls: " << lbl->get_name() << " is already clbl: " << m_is_clbl[lbl_id] << "\n";); if (m_is_clbl[lbl_id]) return; - m_trail_stack.push(set_bitvector_trail(m_is_clbl, lbl_id)); + ctx.push(set_bitvector_trail(m_is_clbl, lbl_id)); SASSERT(m_is_clbl[lbl_id]); unsigned h = m_lbl_hasher(lbl); for (enode* app : m_egraph.enodes_of(lbl)) { @@ -3158,7 +3159,7 @@ namespace euf { enode * c = app->get_arg(i); approx_set & r_plbls = c->get_root()->get_plbls(); if (!r_plbls.may_contain(elem)) { - m_trail_stack.push(mam_value_trail(r_plbls)); + ctx.push(mam_value_trail(r_plbls)); r_plbls.insert(elem); TRACE("trigger_bug", tout << "updating plabels of:\n" << mk_ismt2_pp(c->get_root()->get_expr(), m) << "\n"; tout << "new_elem: " << static_cast(elem) << "\n"; @@ -3179,7 +3180,7 @@ namespace euf { TRACE("mam_bug", tout << "update_plbls: " << lbl->get_name() << " is already plbl: " << m_is_plbl[lbl_id] << "\n";); if (m_is_plbl[lbl_id]) return; - m_trail_stack.push(set_bitvector_trail(m_is_plbl, lbl_id)); + ctx.push(set_bitvector_trail(m_is_plbl, lbl_id)); SASSERT(m_is_plbl[lbl_id]); SASSERT(is_plbl(lbl)); unsigned h = m_lbl_hasher(lbl); @@ -3226,7 +3227,7 @@ namespace euf { p = p->m_child; } curr->m_code = mk_code(qa, mp, pat_idx); - m_trail_stack.push(new_obj_trail(curr->m_code)); + ctx.push(new_obj_trail(curr->m_code)); return head; } @@ -3249,7 +3250,7 @@ namespace euf { insert_code(t, qa, mp, p->m_pattern_idx); } else { - m_trail_stack.push(set_ptr_trail(t->m_first_child)); + ctx.push(set_ptr_trail(t->m_first_child)); t->m_first_child = mk_path_tree(p->m_child, qa, mp); } } @@ -3259,9 +3260,9 @@ namespace euf { insert_code(t, qa, mp, p->m_pattern_idx); } else { - m_trail_stack.push(set_ptr_trail(t->m_code)); + ctx.push(set_ptr_trail(t->m_code)); t->m_code = mk_code(qa, mp, p->m_pattern_idx); - m_trail_stack.push(new_obj_trail(t->m_code)); + ctx.push(new_obj_trail(t->m_code)); } } else { @@ -3274,10 +3275,10 @@ namespace euf { prev_sibling = t; t = t->m_sibling; } - m_trail_stack.push(set_ptr_trail(prev_sibling->m_sibling)); + ctx.push(set_ptr_trail(prev_sibling->m_sibling)); prev_sibling->m_sibling = mk_path_tree(p, qa, mp); if (!found_label) { - m_trail_stack.push(value_trail(head->m_filter)); + ctx.push(value_trail(head->m_filter)); head->m_filter.insert(m_lbl_hasher(p->m_label)); } } @@ -3287,7 +3288,7 @@ namespace euf { insert(m_pc[h1][h2], p, qa, mp); } else { - m_trail_stack.push(set_ptr_trail(m_pc[h1][h2])); + ctx.push(set_ptr_trail(m_pc[h1][h2])); m_pc[h1][h2] = mk_path_tree(p, qa, mp); } TRACE("mam_path_tree_updt", @@ -3304,7 +3305,7 @@ namespace euf { insert(m_pp[h1][h2].first, p2, qa, mp); } else { - m_trail_stack.push(set_ptr_trail(m_pp[h1][h2].first)); + ctx.push(set_ptr_trail(m_pp[h1][h2].first)); m_pp[h1][h2].first = mk_path_tree(p1, qa, mp); insert(m_pp[h1][h2].first, p2, qa, mp); } @@ -3321,9 +3322,9 @@ namespace euf { insert(m_pp[h1][h2].second, p2, qa, mp); } else { - SASSERT(m_pp[h1][h2].second == 0); - m_trail_stack.push(set_ptr_trail(m_pp[h1][h2].first)); - m_trail_stack.push(set_ptr_trail(m_pp[h1][h2].second)); + SASSERT(m_pp[h1][h2].second == nullptr); + ctx.push(set_ptr_trail(m_pp[h1][h2].first)); + ctx.push(set_ptr_trail(m_pp[h1][h2].second)); m_pp[h1][h2].first = mk_path_tree(p1, qa, mp); m_pp[h1][h2].second = mk_path_tree(p2, qa, mp); } @@ -3360,7 +3361,7 @@ namespace euf { expr * arg = pat->get_arg(i); if (is_ground(arg)) { pos = i; - return mk_enode(m_egraph, qa, to_app(arg)); + return m_egraph.find(arg); } } return nullptr; @@ -3388,7 +3389,7 @@ namespace euf { SASSERT(is_app(child)); if (to_app(child)->is_ground()) { - enode * n = mk_enode(m_egraph, qa, to_app(child)); + enode * n = m_egraph.find(child); update_plbls(plbl); if (!n->has_lbl_hash()) n->set_lbl_hash(m_egraph); @@ -3458,8 +3459,8 @@ namespace euf { bool is_eq(enode * n1, enode * n2) { return n1->get_root() == n2->get_root() || - (n1->get_root() == m_r1 && n2->get_root() == m_r2) || - (n2->get_root() == m_r1 && n1->get_root() == m_r2); + (n1->get_root() == m_other && n2->get_root() == m_root) || + (n2->get_root() == m_other && n1->get_root() == m_root); } /** @@ -3532,7 +3533,7 @@ namespace euf { #endif TRACE("mam_path_tree", tout << "processing: #" << curr_child->get_expr_id() << "\n";); - for (enode* curr_parent : enode_parents(curr_child)) { + for (enode* curr_parent : euf::enode_parents(curr_child)) { #ifdef _PROFILE_PATH_TREE if (curr_parent->is_equality()) t->m_num_eq_visited++; @@ -3682,9 +3683,8 @@ namespace euf { TRACE("mam_new_pat", tout << "matching new patterns:\n";); m_tmp_trees_to_delete.reset(); for (auto const& kv : m_new_patterns) { - if (!m.inc()) { + if (!m.inc()) break; - } quantifier * qa = kv.first; app * mp = kv.second; SASSERT(m.is_pattern(mp)); @@ -3712,62 +3712,30 @@ namespace euf { for (enode * app : m_egraph.enodes_of(lbl)) if (m_egraph.is_relevant(app)) m_interpreter.execute_core(tmp_tree, app); - m_tmp_trees[lbl_id] = 0; + m_tmp_trees[lbl_id] = nullptr; dealloc(tmp_tree); } m_new_patterns.reset(); } - void collect_ground_exprs(quantifier * qa, app * mp) { - ptr_buffer todo; - unsigned num_patterns = mp->get_num_args(); - for (unsigned i = 0; i < num_patterns; i++) { - app * pat = to_app(mp->get_arg(i)); - TRACE("mam_pat", tout << mk_ismt2_pp(qa, m) << "\npat:\n" << mk_ismt2_pp(pat, m) << "\n";); - SASSERT(!pat->is_ground()); - todo.push_back(pat); - } - while (!todo.empty()) { - app * n = todo.back(); - todo.pop_back(); - if (n->is_ground()) { - enode * e = mk_enode(m_egraph, qa, n); - m_trail_stack.push(add_shared_enode_trail(e)); - m_shared_enodes.insert(e); - } - else { - unsigned num_args = n->get_num_args(); - for (unsigned i = 0; i < num_args; i++) { - expr * arg = n->get_arg(i); - if (is_app(arg)) - todo.push_back(to_app(arg)); - } - } - } - } - - public: - mam_impl(egraph & ctx, std::function& add_instance, bool use_filters): - m_egraph(ctx), - m_add_instance(add_instance), + mam_impl(euf::solver & ctx, ematch& ematch, bool use_filters): + ctx(ctx), + m_egraph(ctx.get_egraph()), + m_ematch(ematch), m(ctx.get_manager()), m_use_filters(use_filters), - m_trail_stack(*this), - m_ct_manager(m_lbl_hasher, m_trail_stack), - m_compiler(ctx, m_ct_manager, m_lbl_hasher, use_filters), - m_interpreter(ctx, *this, use_filters), - m_trees(m, m_compiler, m_trail_stack), - m_region(m_trail_stack.get_region()), - m_r1(nullptr), - m_r2(nullptr) { - DEBUG_CODE(m_trees.set_egraph(&ctx);); + m_ct_manager(m_lbl_hasher, ctx), + m_compiler(m_egraph, m_ct_manager, m_lbl_hasher, use_filters), + m_interpreter(m_egraph, *this, use_filters), + m_trees(m, m_compiler, ctx), + m_region(ctx.get_region()) { + DEBUG_CODE(m_trees.set_egraph(&m_egraph);); DEBUG_CODE(m_check_missing_instances = false;); reset_pp_pc(); } ~mam_impl() override { - m_trail_stack.reset(); } void add_pattern(quantifier * qa, app * mp) override { @@ -3778,38 +3746,21 @@ namespace euf { // Ground patterns are discarded. // However, the simplifier may turn a non-ground pattern into a ground one. // So, we should check it again here. - unsigned num_patterns = mp->get_num_args(); - for (unsigned i = 0; i < num_patterns; i++) - if (is_ground(mp->get_arg(i))) + for (expr* arg : *mp) + if (is_ground(arg)) return; // ignore multi-pattern containing ground pattern. update_filters(qa, mp); - collect_ground_exprs(qa, mp); m_new_patterns.push_back(qp_pair(qa, mp)); + ctx.push(reset_new_patterns(*this)); // The matching abstract machine implements incremental // e-matching. So, for a multi-pattern [ p_1, ..., p_n ], // we have to make n insertions. In the i-th insertion, // the pattern p_i is assumed to be the first one. - for (unsigned i = 0; i < num_patterns; i++) + for (unsigned i = 0; i < mp->get_num_args(); i++) m_trees.add_pattern(qa, mp, i); } - void push_scope() override { - m_trail_stack.push_scope(); - } - - void pop_scope(unsigned num_scopes) override { - if (!m_to_match.empty()) { - for (code_tree* t : m_to_match) { - t->reset_candidates(); - } - m_to_match.reset(); - } - m_new_patterns.reset(); - m_trail_stack.pop_scope(num_scopes); - } - void reset() override { - m_trail_stack.reset(); m_trees.reset(); m_to_match.reset(); m_new_patterns.reset(); @@ -3828,7 +3779,7 @@ namespace euf { return out; } - void match() override { + void propagate() override { TRACE("trigger_bug", tout << "match\n"; display(tout);); for (code_tree* t : m_to_match) { SASSERT(t->has_candidates()); @@ -3887,12 +3838,9 @@ namespace euf { unsigned min_gen = 0, max_gen = 0; m_interpreter.get_min_max_top_generation(min_gen, max_gen); UNREACHABLE(); -// m_add_instance(qa, pat, num_bindings, bindings, nullptr, max_generation, min_gen, max_gen); + // m_ematch.on_binding(qa, pat, bindings); // max_generation); // , min_gen, max_gen; } - bool is_shared(enode * n) const override { - return !m_shared_enodes.empty() && m_shared_enodes.contains(n); - } // This method is invoked when n becomes relevant. // If lazy == true, then n is not added to the list of candidate enodes for matching. That is, the method just updates the lbls. @@ -3919,49 +3867,71 @@ namespace euf { } } - bool has_work() const override { + bool can_propagate() const override { return !m_to_match.empty() || !m_new_patterns.empty(); } - void add_eq_eh(enode * r1, enode * r2) override { - flet l1(m_r1, r1); - flet l2(m_r2, r2); + void on_merge(enode * root, enode * other) override { + flet l1(m_other, other); + flet l2(m_root, root); - TRACE("mam", tout << "add_eq_eh: #" << r1->get_expr_id() << " #" << r2->get_expr_id() << "\n";); + TRACE("mam", tout << "add_eq_eh: #" << other->get_expr_id() << " #" << root->get_expr_id() << "\n";); TRACE("mam_inc_bug_detail", m_egraph.display(tout);); TRACE("mam_inc_bug", - tout << "before:\n#" << r1->get_expr_id() << " #" << r2->get_expr_id() << "\n"; - tout << "r1.lbls: " << r1->get_lbls() << "\n"; - tout << "r2.lbls: " << r2->get_lbls() << "\n"; - tout << "r1.plbls: " << r1->get_plbls() << "\n"; - tout << "r2.plbls: " << r2->get_plbls() << "\n";); + tout << "before:\n#" << other->get_expr_id() << " #" << root->get_expr_id() << "\n"; + tout << "other.lbls: " << other->get_lbls() << "\n"; + tout << "root.lbls: " << root->get_lbls() << "\n"; + tout << "other.plbls: " << other->get_plbls() << "\n"; + tout << "root.plbls: " << root->get_plbls() << "\n";); - process_pc(r1, r2); - process_pc(r2, r1); - process_pp(r1, r2); + process_pc(other, root); + process_pc(root, other); + process_pp(other, root); - approx_set r1_plbls = r1->get_plbls(); - approx_set & r2_plbls = r2->get_plbls(); - approx_set r1_lbls = r1->get_lbls(); - approx_set & r2_lbls = r2->get_lbls(); + approx_set other_plbls = other->get_plbls(); + approx_set & root_plbls = root->get_plbls(); + approx_set other_lbls = other->get_lbls(); + approx_set & root_lbls = root->get_lbls(); - m_trail_stack.push(mam_value_trail(r2_lbls)); - m_trail_stack.push(mam_value_trail(r2_plbls)); - r2_lbls |= r1_lbls; - r2_plbls |= r1_plbls; + ctx.push(mam_value_trail(root_lbls)); + ctx.push(mam_value_trail(root_plbls)); + root_lbls |= other_lbls; + root_plbls |= other_plbls; TRACE("mam_inc_bug", tout << "after:\n"; - tout << "r1.lbls: " << r1->get_lbls() << "\n"; - tout << "r2.lbls: " << r2->get_lbls() << "\n"; - tout << "r1.plbls: " << r1->get_plbls() << "\n"; - tout << "r2.plbls: " << r2->get_plbls() << "\n";); - SASSERT(approx_subset(r1->get_plbls(), r2->get_plbls())); - SASSERT(approx_subset(r1->get_lbls(), r2->get_lbls())); + tout << "other.lbls: " << other->get_lbls() << "\n"; + tout << "root.lbls: " << root->get_lbls() << "\n"; + tout << "other.plbls: " << other->get_plbls() << "\n"; + tout << "root.plbls: " << root->get_plbls() << "\n";); + SASSERT(approx_subset(other->get_plbls(), root->get_plbls())); + SASSERT(approx_subset(other->get_lbls(), root->get_lbls())); } }; - mam* mam::mk(egraph& ctx, std::function& add_instance) { - return alloc(mam_impl, ctx, add_instance, true); + void mam::ground_subterms(expr* e, ptr_vector& ground) { + ground.reset(); + expr_fast_mark1 mark; + ptr_buffer todo; + if (is_app(e)) + todo.push_back(to_app(e)); + while (!todo.empty()) { + app * n = todo.back(); + todo.pop_back(); + if (mark.is_marked(n)) + continue; + mark.mark(n); + if (n->is_ground()) + ground.push_back(n); + else { + for (expr* arg : *n) + if (is_app(arg)) + todo.push_back(to_app(arg)); + } + } + } + + mam* mam::mk(euf::solver& ctx, ematch& em) { + return alloc(mam_impl, ctx, em, true); } } diff --git a/src/ast/euf/euf_mam.h b/src/sat/smt/q_mam.h similarity index 58% rename from src/ast/euf/euf_mam.h rename to src/sat/smt/q_mam.h index 1f9236215..194ce7b19 100644 --- a/src/ast/euf/euf_mam.h +++ b/src/sat/smt/q_mam.h @@ -22,6 +22,16 @@ Author: #include namespace euf { + class solver; +}; + +namespace q { + + typedef euf::enode enode; + typedef euf::egraph egraph; + typedef euf::enode_vector enode_vector; + + class ematch; /** \brief Matching Abstract Machine (MAM) @@ -33,34 +43,29 @@ namespace euf { public: - static mam * mk(egraph & ctx, - std::function& add_instance); + static mam * mk(euf::solver& ctx, ematch& em); virtual ~mam() {} virtual void add_pattern(quantifier * q, app * mp) = 0; - virtual void push_scope() = 0; + virtual void propagate() = 0; - virtual void pop_scope(unsigned num_scopes) = 0; - - virtual void match() = 0; + virtual bool can_propagate() const = 0; virtual void rematch(bool use_irrelevant = false) = 0; - - virtual bool has_work() const = 0; - virtual void add_eq_eh(enode * r1, enode * r2) = 0; + virtual void on_merge(enode * root, enode * other) = 0; virtual void reset() = 0; virtual std::ostream& display(std::ostream& out) = 0; - - virtual void on_match(quantifier * q, app * pat, unsigned num_bindings, enode * const * bindings, unsigned max_generation) = 0; - - virtual bool is_shared(enode * n) const = 0; - + virtual bool check_missing_instances() = 0; + + virtual void on_match(quantifier * qa, app * pat, unsigned num_bindings, enode * const * bindings, unsigned max_generation) = 0; + + static void ground_subterms(expr* e, ptr_vector& ground); }; }; diff --git a/src/sat/smt/q_solver.h b/src/sat/smt/q_solver.h index 0e348f1b2..2c91fd822 100644 --- a/src/sat/smt/q_solver.h +++ b/src/sat/smt/q_solver.h @@ -55,6 +55,8 @@ namespace q { expr_ref_vector const& expand(quantifier* q); + friend class ematch; + public: solver(euf::solver& ctx, family_id fid); @@ -75,6 +77,7 @@ namespace q { euf::theory_var mk_var(euf::enode* n) override; void init_search() override; void finalize_model(model& mdl) override; + bool is_shared(euf::theory_var v) const override { return true; } ast_manager& get_manager() { return m; } sat::literal_vector const& universal() const { return m_universal; } diff --git a/src/sat/smt/sat_th.h b/src/sat/smt/sat_th.h index b887e6da2..5576672d9 100644 --- a/src/sat/smt/sat_th.h +++ b/src/sat/smt/sat_th.h @@ -152,7 +152,6 @@ namespace euf { sat::literal eq_internalize(expr* a, expr* b); sat::literal eq_internalize(enode* a, enode* b) { return eq_internalize(a->get_expr(), b->get_expr()); } - euf::enode* e_internalize(expr* e); euf::enode* mk_enode(expr* e, bool suppress_args = false); expr_ref mk_eq(expr* e1, expr* e2); expr_ref mk_var_eq(theory_var v1, theory_var v2) { return mk_eq(var2expr(v1), var2expr(v2)); } @@ -173,6 +172,7 @@ namespace euf { virtual ~th_euf_solver() {} virtual theory_var mk_var(enode* n); unsigned get_num_vars() const { return m_var2enode.size(); } + euf::enode* e_internalize(expr* e); enode* expr2enode(expr* e) const; enode* var2enode(theory_var v) const { return m_var2enode[v]; } expr* var2expr(theory_var v) const { return var2enode(v)->get_expr(); } diff --git a/src/util/nat_set.h b/src/util/nat_set.h index 94923ae57..005638492 100644 --- a/src/util/nat_set.h +++ b/src/util/nat_set.h @@ -72,13 +72,9 @@ public: } bool empty() const { - svector::const_iterator it = m_timestamps.begin(); - svector::const_iterator end = m_timestamps.end(); - for (; it != end; ++it) { - if (*it > m_curr_timestamp) { + for (auto const& t : m_timestamps) + if (t > m_curr_timestamp) return false; - } - } return true; } };