diff --git a/src/ast/euf/euf_ac_plugin.cpp b/src/ast/euf/euf_ac_plugin.cpp index 329976213..e47c4f30d 100644 --- a/src/ast/euf/euf_ac_plugin.cpp +++ b/src/ast/euf/euf_ac_plugin.cpp @@ -125,29 +125,39 @@ namespace euf { UNREACHABLE(); } } + + std::ostream& ac_plugin::display_monomial(std::ostream& out, ptr_vector const& m) const { + for (auto n : m) + out << g.bpp(n->n) << " "; + return out; + } + + std::ostream& ac_plugin::display_equation(std::ostream& out, eq const& e) const { + display_monomial(out, monomial(e.l)); + out << " == "; + display_monomial(out, monomial(e.r)); + return out; + } std::ostream& ac_plugin::display(std::ostream& out) const { unsigned i = 0; for (auto const& eq : m_eqs) { out << i << ": " << eq.l << " == " << eq.r << ": "; - for (auto n : monomial(eq.l)) - out << g.bpp(n->n) << " "; - out << "== "; - for (auto n : monomial(eq.r)) - out << g.bpp(n->n) << " "; + display_equation(out, eq); out << "\n"; ++i; } i = 0; for (auto m : m_monomials) { out << i << ": "; - for (auto n : m) - out << g.bpp(n->n) << " "; + display_monomial(out, m); out << "\n"; ++i; } for (auto n : m_nodes) { - out << g.bpp(n->n) << " r: " << n->root_id() << "\n"; + if (!n) + continue; + out << g.bpp(n->n) << " r: " << n->root_id() << " - "; out << "lhs "; for (auto l : n->lhs) out << l << " "; @@ -279,8 +289,10 @@ namespace euf { } void ac_plugin::propagate() { + TRACE("plugin", display(tout)); while (true) { unsigned eq_id = pick_next_eq(); + TRACE("plugin", tout << "propagate " << eq_id << "\n"); if (eq_id == UINT_MAX) break; @@ -516,6 +528,7 @@ namespace euf { auto& src = m_eqs[src_eq]; auto& dst = m_eqs[dst_eq]; + TRACE("plugin", tout << "superpose: "; display_equation(tout, src); tout << " "; display_equation(tout, dst); tout << "\n";); // AB -> C, AD -> E => BE ~ CD // m_src_ids, m_src_counts contains information about src (call it AD -> E) reset_ids_counts(m_dst_ids, m_dst_count); @@ -523,6 +536,7 @@ namespace euf { m_dst_r.reset(); m_dst_r.append(monomial(dst.r)); unsigned src_r_size = m_src_r.size(); + unsigned dst_r_size = m_dst_r.size(); SASSERT(src_r_size == monomial(src.r).size()); // dst_r contains C // src_r contains E @@ -546,6 +560,37 @@ namespace euf { else m_dst_r.push_back(n); } + // one side is a proper subset of the other + if (m_src_r.size() == src_r_size || m_dst_r.size() == dst_r_size) { + m_src_r.shrink(src_r_size); + return; + } + if (m_src_r.size() == m_dst_r.size()) { + reset_ids_counts(m_dst_ids, m_dst_count); + bool are_equal = true; + for (auto n : m_dst_r) { + unsigned id = n->root_id(); + unsigned dst_count = m_dst_count.get(id, 0); + m_dst_count.set(id, dst_count + 1); + } + for (auto n : m_src_r) { + unsigned id = n->root_id(); + unsigned dst_count = m_dst_count.get(id, 0); + if (dst_count > 0) + m_dst_count[id]--; + else { + are_equal = false; + break; + } + } + if (are_equal) { + m_src_r.shrink(src_r_size); + return; + } + } + + TRACE("plugin", for (auto n : m_src_r) tout << g.bpp(n->n) << " "; tout << "== "; for (auto n : m_dst_r) tout << g.bpp(n->n) << " "; tout << "\n";); + justification j = justify_rewrite(src_eq, dst_eq); if (m_src_r.size() == 1 && m_dst_r.size() == 1) diff --git a/src/ast/euf/euf_ac_plugin.h b/src/ast/euf/euf_ac_plugin.h index ee8b602eb..ccb961c65 100644 --- a/src/ast/euf/euf_ac_plugin.h +++ b/src/ast/euf/euf_ac_plugin.h @@ -105,7 +105,7 @@ namespace euf { node* mk_node(enode* n); void merge(node* r1, node* r2, justification j); - bool is_op(enode* n) const { auto d = n->get_decl(); return d && m_fid == d->get_family_id() && m_op == d->get_kind(); } + bool is_op(enode* n) const { auto d = n->get_decl(); return d && m_fid == d->get_family_id() && m_op == d->get_decl_kind(); } std::function m_undo_notify; void push_undo(undo_kind k); @@ -148,6 +148,9 @@ namespace euf { void propagate_shared(); void simplify_shared(unsigned monomial_id); + std::ostream& display_monomial(std::ostream& out, ptr_vector const& m) const; + std::ostream& display_equation(std::ostream& out, eq const& e) const; + public: ac_plugin(egraph& g, unsigned fid, unsigned op); diff --git a/src/ast/euf/euf_egraph.cpp b/src/ast/euf/euf_egraph.cpp index dbcfb51d2..b0b098fff 100644 --- a/src/ast/euf/euf_egraph.cpp +++ b/src/ast/euf/euf_egraph.cpp @@ -140,10 +140,12 @@ namespace euf { auto* p = get_plugin(n); if (p) p->register_node(n); - for (auto* arg : enode_args(n)) { - auto* p_arg = get_plugin(arg); - if (p != p_arg) - p_arg->register_shared(arg); + if (!n->is_equality()) { + for (auto* arg : enode_args(n)) { + auto* p_arg = get_plugin(arg); + if (p != p_arg) + p_arg->register_shared(arg); + } } } @@ -620,6 +622,7 @@ namespace euf { bool egraph::propagate() { force_push(); + propagate_plugins(); for (unsigned i = 0; i < m_to_merge.size() && m.limit().inc() && !inconsistent(); ++i) { auto const& w = m_to_merge[i]; if (w.j.is_congruence())