diff --git a/src/ast/euf/euf_ac_plugin.cpp b/src/ast/euf/euf_ac_plugin.cpp index 14eb90a86..89d42bf43 100644 --- a/src/ast/euf/euf_ac_plugin.cpp +++ b/src/ast/euf/euf_ac_plugin.cpp @@ -63,7 +63,7 @@ TODOs: namespace euf { - ac_plugin::ac_plugin(egraph& g, unsigned fid, unsigned op): + ac_plugin::ac_plugin(egraph& g, unsigned fid, unsigned op) : plugin(g), m_fid(fid), m_op(op), m_dep_manager(get_region()), m_hash(*this), m_eq(*this), m_monomial_table(m_hash, m_eq) @@ -107,7 +107,7 @@ namespace euf { n->~node(); break; } - case is_add_monomial: { + case is_add_monomial: { m_monomials.pop_back(); break; } @@ -122,7 +122,7 @@ namespace euf { break; } case is_update_eq: { - auto const & [idx, eq] = m_update_eq_trail.back(); + auto const& [idx, eq] = m_update_eq_trail.back(); m_eqs[idx] = eq; m_update_eq_trail.pop_back(); break; @@ -160,7 +160,7 @@ namespace euf { display_monomial(out, monomial(e.r)); return out; } - + std::ostream& ac_plugin::display(std::ostream& out) const { unsigned i = 0; for (auto const& eq : m_eqs) { @@ -198,10 +198,10 @@ namespace euf { void ac_plugin::merge_eh(enode* l, enode* r, justification j) { if (l == r) return; - if (!is_op(l) && !is_op(r)) + if (!is_op(l) && !is_op(r)) merge(mk_node(l), mk_node(r), j); - else - init_equation(eq(to_monomial(l), to_monomial(r), j)); + else + init_equation(eq(to_monomial(l), to_monomial(r), j)); } void ac_plugin::init_equation(eq const& e) { @@ -248,28 +248,30 @@ namespace euf { } bool ac_plugin::is_sorted(monomial_t const& m) const { - if (m.m_filter.m_tick == m_tick) + if (m.m_bloom.m_tick == m_tick) return true; - for (unsigned i = m.size(); i-- > 1; ) - if (m[i-1]->root_id() > m[i]->root_id()) - return false; + for (unsigned i = m.size(); i-- > 1; ) + if (m[i - 1]->root_id() > m[i]->root_id()) + return false; return true; } uint64_t ac_plugin::filter(monomial_t& m) { - auto& filter = m.m_filter; - if (filter.m_tick == m_tick) - return filter.m_filter; - filter.m_filter = 0; + auto& bloom = m.m_bloom; + if (bloom.m_tick == m_tick) + return bloom.m_filter; + bloom.m_filter = 0; for (auto n : m) - filter.m_filter |= (1ull << (n->root_id() % 64ull)); + bloom.m_filter |= (1ull << (n->root_id() % 64ull)); if (!is_sorted(m)) sort(m); - filter.m_tick = m_tick; - return filter.m_filter; + bloom.m_tick = m_tick; + return bloom.m_filter; } bool ac_plugin::can_be_subset(monomial_t& subset, monomial_t& superset) { + if (subset.size() > superset.size()) + return false; auto f1 = filter(subset); auto f2 = filter(superset); return (f1 | f2) == f2; @@ -278,9 +280,9 @@ namespace euf { void ac_plugin::merge(node* root, node* other, justification j) { for (auto n : equiv(other)) n->root = root; - m_merge_trail.push_back({ other, root->shared.size(), root->eqs.size()}); + m_merge_trail.push_back({ other, root->shared.size(), root->eqs.size() }); for (auto eq_id : other->eqs) - set_processed(eq_id, false); + set_status(eq_id, eq_status::to_simplify); for (auto m : other->shared) m_shared_todo.insert(m); root->shared.append(other->shared); @@ -346,17 +348,17 @@ namespace euf { void ac_plugin::propagate() { TRACE("plugin", display(tout)); while (true) { + loop_start: unsigned eq_id = pick_next_eq(); TRACE("plugin", tout << "propagate " << eq_id << "\n"); if (eq_id == UINT_MAX) break; + eq& eq = m_eqs[eq_id]; // simplify eq using processed for (auto other_eq : backward_iterator(eq_id)) - if (is_processed(other_eq)) - backward_simplify(eq_id, other_eq); - if (m_backward_simplified) - continue; + if (is_processed(other_eq) && backward_simplify(eq_id, other_eq)) + goto loop_start; // simplify processed using eq for (auto other_eq : forward_iterator(eq_id)) @@ -373,14 +375,14 @@ namespace euf { if (is_to_simplify(other_eq)) forward_simplify(other_eq, eq_id); - set_processed(eq_id, true); + set_status(eq_id, eq_status::processed); } propagate_shared(); } unsigned ac_plugin::pick_next_eq() { while (!m_to_simplify_todo.empty()) { - unsigned id = *m_to_simplify_todo.begin(); + unsigned id = *m_to_simplify_todo.begin(); if (id < m_eqs.size() && is_to_simplify(id)) return id; m_to_simplify_todo.remove(id); @@ -388,17 +390,28 @@ namespace euf { return UINT_MAX; } - void ac_plugin::set_processed(unsigned id, bool f) { + void ac_plugin::set_status(unsigned id, eq_status s) { auto& eq = m_eqs[id]; - if (eq.is_processed == f) + if (eq.status == eq_status::is_dead) return; - if (f) + if (s == eq_status::to_simplify && are_equal(monomial(eq.l), monomial(eq.r))) + s = eq_status::is_dead; + + if (eq.status != s) { + m_update_eq_trail.push_back({ id, eq }); + eq.status = s; + push_undo(is_update_eq); + } + switch (s) { + case eq_status::processed: + case eq_status::is_dead: m_to_simplify_todo.remove(id); - else + break; + case eq_status::to_simplify: m_to_simplify_todo.insert(id); - m_update_eq_trail.push_back({ id, eq }); - eq.is_processed = f; - push_undo(is_update_eq); + orient_equation(eq); + break; + } } // @@ -408,7 +421,7 @@ namespace euf { auto const& eq = m_eqs[eq_id]; m_src_r.reset(); m_src_r.append(monomial(eq.r).m_nodes); - init_ids_counts(monomial(eq.l), m_src_ids, m_src_count); + init_ref_counts(monomial(eq.l), m_src_l_counts); init_overlap_iterator(eq_id, monomial(eq.l)); return m_eq_occurs; } @@ -419,9 +432,9 @@ namespace euf { // unsigned_vector const& ac_plugin::backward_iterator(unsigned eq_id) { auto const& eq = m_eqs[eq_id]; - init_ids_counts(monomial(eq.r), m_dst_ids, m_dst_count); + init_ref_counts(monomial(eq.r), m_dst_r_counts); + init_ref_counts(monomial(eq.l), m_dst_l_counts); init_overlap_iterator(eq_id, monomial(eq.r)); - m_backward_simplified = false; return m_eq_occurs; } @@ -454,7 +467,8 @@ namespace euf { auto& eq = m_eqs[eq_id]; m_src_r.reset(); m_src_r.append(monomial(eq.r).m_nodes); - init_ids_counts(monomial(eq.l), m_src_ids, m_src_count); + init_ref_counts(monomial(eq.l), m_src_l_counts); + init_ref_counts(monomial(eq.r), m_src_r_counts); unsigned min_r = UINT_MAX; node* min_n = nullptr; for (auto n : monomial(eq.l)) @@ -465,22 +479,13 @@ namespace euf { return min_n->eqs; } - void ac_plugin::init_ids_counts(monomial_t const& monomial, unsigned_vector& ids, unsigned_vector& counts) { - reset_ids_counts(ids, counts); - for (auto n : monomial) { - unsigned id = n->root_id(); - counts.setx(id, counts.get(id, 0) + 1, 0); - ids.push_back(id); - } + void ac_plugin::init_ref_counts(monomial_t const& monomial, ref_counts& counts) { + counts.reset(); + for (auto n : monomial) + counts.inc(n->root_id(), 1); } - void ac_plugin::reset_ids_counts(unsigned_vector& ids, unsigned_vector& counts) { - for (auto id : ids) - counts[id] = 0; - ids.reset(); - } - - void ac_plugin::forward_simplify(unsigned dst_eq, unsigned src_eq) { + void ac_plugin::forward_simplify(unsigned dst_eq, unsigned src_eq) { if (src_eq == dst_eq) return; @@ -495,7 +500,12 @@ namespace euf { if (!can_be_subset(monomial(src.l), monomial(dst.r))) return; - reset_ids_counts(m_dst_ids, m_dst_count); + if (forward_subsumes(src_eq, dst_eq)) { + set_status(dst_eq, eq_status::is_dead); + return; + } + + m_dst_r_counts.reset(); unsigned src_l_size = monomial(src.l).size(); unsigned src_r_size = m_src_r.size(); @@ -505,16 +515,13 @@ namespace euf { unsigned num_overlap = 0; for (auto n : monomial(dst.r)) { unsigned id = n->root_id(); - unsigned count = m_src_count.get(id, 0); - if (count == 0) + unsigned count = m_src_l_counts[id]; + if (count == 0) m_src_r.push_back(n); - else { - unsigned dst_count = m_dst_count.get(id, 0); - if (dst_count >= count) - m_src_r.push_back(n); - else - m_dst_count.set(id, dst_count + 1), m_dst_ids.push_back(id), ++num_overlap; - } + else if (m_dst_r_counts[id] >= count) + m_src_r.push_back(n); + else + m_dst_r_counts.inc(id, 1), ++num_overlap; } // The dst.r has to be a superset of src.l, otherwise simplification does not apply if (num_overlap == src_l_size) { @@ -527,9 +534,9 @@ namespace euf { m_src_r.shrink(src_r_size); } - void ac_plugin::backward_simplify(unsigned dst_eq, unsigned src_eq) { + bool ac_plugin::backward_simplify(unsigned dst_eq, unsigned src_eq) { if (src_eq == dst_eq) - return; + return false; auto& src = m_eqs[src_eq]; auto& dst = m_eqs[dst_eq]; @@ -539,30 +546,80 @@ namespace euf { // check that src.l is a subset of dst.r if (!can_be_subset(monomial(src.l), monomial(dst.r))) - return; - if (!is_subset(monomial(src.l))) - return; - + return false; + if (!is_subset(m_dst_l_counts, m_src_l_counts, monomial(src.l))) + return false; + if (backward_subsumes(src_eq, dst_eq)) { + set_status(dst_eq, eq_status::is_dead); + return true; + } // dst_rhs := dst_rhs - src_lhs + src_rhs auto new_r = rewrite(monomial(src.r), monomial(dst.r)); m_update_eq_trail.push_back({ dst_eq, m_eqs[dst_eq] }); m_eqs[dst_eq].r = new_r; m_eqs[dst_eq].j = justify_rewrite(src_eq, dst_eq); push_undo(is_update_eq); - m_backward_simplified = true; + return true; } - bool ac_plugin::subsumes(unsigned src_eq, unsigned dst_eq) { + // dst_eq is fixed, dst_count is pre-computed for monomial(dst.l) + // dst2_counts is pre-computed for monomial(dst.r). + // is dst_eq subsumed by src_eq? + bool ac_plugin::backward_subsumes(unsigned src_eq, unsigned dst_eq) { auto& src = m_eqs[src_eq]; auto& dst = m_eqs[dst_eq]; if (!can_be_subset(monomial(src.l), monomial(dst.l))) return false; if (!can_be_subset(monomial(src.r), monomial(dst.r))) return false; + unsigned size_diff = monomial(dst.l).size() - monomial(src.l).size(); + if (size_diff != monomial(dst.r).size() - monomial(src.r).size()) + return false; + if (!is_subset(m_dst_l_counts, m_src_l_counts, monomial(src.l))) + return false; + if (!is_subset(m_dst_r_counts, m_src_r_counts, monomial(src.r))) + return false; + // add difference betwen src and dst1 to dst2 + // (also add it to dst1 to make sure same difference isn't counted twice). + for (auto n : monomial(src.l)) { + unsigned id = n->root_id(); + SASSERT(m_src_l_counts[id] >= m_dst_l_counts[id]); + unsigned diff = m_src_l_counts[id] - m_dst_l_counts[id]; + if (diff > 0) { + m_dst_l_counts.inc(id, diff); + m_dst_r_counts.inc(id, diff); + } + } + // now dst2 and src2 should align and have the same elements. + // since src.r is a subset of dst.r we iterate over dst.r + return all_of(monomial(dst.r), [&](node* n) { unsigned id = n->root_id(); return m_src_r_counts[id] == m_dst_r_counts[id]; }); + } - NOT_IMPLEMENTED_YET(); - // dst.l \ src.l = dst.r \ dst.r - return false; + // src_counts, src2_counts are initialized for src_eq + bool ac_plugin::forward_subsumes(unsigned src_eq, unsigned dst_eq) { + auto& src = m_eqs[src_eq]; + auto& dst = m_eqs[dst_eq]; + if (!can_be_subset(monomial(src.l), monomial(dst.l))) + return false; + if (!can_be_subset(monomial(src.r), monomial(dst.r))) + return false; + unsigned size_diff = monomial(dst.l).size() - monomial(src.l).size(); + if (size_diff != monomial(dst.r).size() - monomial(src.r).size()) + return false; + if (!is_superset(m_src_l_counts, m_dst_l_counts, monomial(dst.l))) + return false; + if (!is_subset(m_src_r_counts, m_dst_r_counts, monomial(dst.r))) + return false; + for (auto n : monomial(src.l)) { + unsigned id = n->root_id(); + SASSERT(m_src_l_counts[id] >= m_dst_l_counts[id]); + unsigned diff = m_src_l_counts[id] - m_dst_l_counts[id]; + if (diff > 0) { + m_dst_l_counts.inc(id, diff); + m_dst_r_counts.inc(id, diff); + } + } + return all_of(monomial(dst.r), [&](node* n) { unsigned id = n->root_id(); return m_src_r_counts[id] == m_dst_r_counts[id]; }); } unsigned ac_plugin::rewrite(monomial_t const& src_r, monomial_t const& dst_r) { @@ -573,33 +630,39 @@ namespace euf { // add to m_src_r elements of dst.r that are not in src.l for (auto n : dst_r) { unsigned id = n->root_id(); - unsigned count = m_src_count.get(id, 0); + unsigned count = m_src_l_counts[id]; if (count == 0) m_src_r.push_back(n); else - --m_src_count[id]; + m_src_l_counts.inc(id, -1); } return to_monomial(nullptr, m_src_r); } - bool ac_plugin::is_subset(monomial_t const& dst) { - reset_ids_counts(m_src_ids, m_src_count); - for (auto n : dst) { + // check that src is a subset of dst, where dst_counts are precomputed + bool ac_plugin::is_subset(ref_counts const& dst_counts, ref_counts& src_counts, monomial_t const& src) { + SASSERT(&dst_counts != &src_counts); + src_counts.reset(); + for (auto n : src) { unsigned id = n->root_id(); - unsigned dst_count = m_dst_count.get(id, 0); + unsigned dst_count = dst_counts[id]; if (dst_count == 0) return false; - else { - unsigned src_count = m_src_count.get(id, 0); - if (src_count >= dst_count) - return false; - else - m_src_count.set(id, src_count + 1), m_src_ids.push_back(id); - } + else if (src_counts[id] >= dst_count) + return false; + else + src_counts.inc(id, 1); } return true; } + // check that dst is a superset of dst, where src_counts are precomputed + bool ac_plugin::is_superset(ref_counts const& src_counts, ref_counts& dst_counts, monomial_t const& dst) { + SASSERT(&dst_counts != &src_counts); + init_ref_counts(dst, dst_counts); + return all_of(src_counts, [&](unsigned idx) { return dst_counts[idx] <= src_counts[idx]; }); + } + void ac_plugin::superpose(unsigned src_eq, unsigned dst_eq) { if (src_eq == dst_eq) return; @@ -609,7 +672,7 @@ namespace euf { TRACE("plugin", tout << "superpose: "; display_equation(tout, src); tout << " "; display_equation(tout, dst); tout << "\n";); // AB -> C, AD -> E => BE ~ CD // m_src_ids, m_src_counts contains information about src (call it AD -> E) - reset_ids_counts(m_dst_ids, m_dst_count); + m_dst_l_counts.reset(); m_dst_r.reset(); m_dst_r.append(monomial(dst.r).m_nodes); @@ -622,63 +685,56 @@ namespace euf { // compute BE, initialize dst_ids, dst_counts for (auto n : monomial(dst.l)) { unsigned id = n->root_id(); - unsigned src_count = m_src_count.get(id, 0); - unsigned dst_count = m_dst_count.get(id, 0); - m_dst_count.set(id, dst_count + 1); - m_dst_ids.push_back(id); - if (src_count < dst_count) - m_src_r.push_back(n); + if (m_src_l_counts[id] < m_dst_l_counts[id]) + m_src_r.push_back(n); + m_dst_l_counts.inc(id, 1); } // compute CD for (auto n : monomial(src.l)) { unsigned id = n->root_id(); - unsigned dst_count = m_dst_count.get(id, 0); - if (dst_count > 0) - --m_dst_count[id]; + if (m_dst_l_counts[id] > 0) + m_dst_l_counts.inc(id, -1); else m_dst_r.push_back(n); } - // one side is a proper subset of the other - if (m_src_r.size() == src_r_size || m_dst_r.size() == dst_r_size) { + + if (are_equal(m_src_r, m_dst_r)) { m_src_r.shrink(src_r_size); return; } - if (m_src_r.size() == m_dst_r.size()) { - reset_ids_counts(m_dst_ids, m_dst_count); - bool are_equal = true; - for (auto n : m_dst_r) { - unsigned id = n->root_id(); - unsigned dst_count = m_dst_count.get(id, 0); - m_dst_count.set(id, dst_count + 1); - } - for (auto n : m_src_r) { - unsigned id = n->root_id(); - unsigned dst_count = m_dst_count.get(id, 0); - if (dst_count > 0) - m_dst_count[id]--; - else { - are_equal = false; - break; - } - } - if (are_equal) { - m_src_r.shrink(src_r_size); - return; - } - } TRACE("plugin", for (auto n : m_src_r) tout << g.bpp(n->n) << " "; tout << "== "; for (auto n : m_dst_r) tout << g.bpp(n->n) << " "; tout << "\n";); - + justification j = justify_rewrite(src_eq, dst_eq); - if (m_src_r.size() == 1 && m_dst_r.size() == 1) - push_merge(m_src_r[0]->n, m_dst_r[0]->n, j); - else + if (m_src_r.size() == 1 && m_dst_r.size() == 1) + push_merge(m_src_r[0]->n, m_dst_r[0]->n, j); + else init_equation(eq(to_monomial(nullptr, m_src_r), to_monomial(nullptr, m_dst_r), j)); m_src_r.shrink(src_r_size); } + bool ac_plugin::are_equal(monomial_t& a, monomial_t& b) { + return filter(a) == filter(b) && are_equal(a.m_nodes, b.m_nodes); + } + + bool ac_plugin::are_equal(ptr_vector const& a, ptr_vector const& b) { + if (a.size() != b.size()) + return false; + m_eq_counts.reset(); + for (auto n : a) + m_eq_counts.inc(n->root_id(), 1); + + for (auto n : b) { + unsigned id = n->root_id(); + if (m_eq_counts[id] == 0) + return false; + m_eq_counts.inc(id, -1); + } + return true; + } + // // simple version based on propagating all shared // todo: version touching only newly processed shared, and maintaining incremental data-structures. @@ -712,13 +768,13 @@ namespace euf { while (change) { change = false; auto & m = monomial(s.m); - init_ids_counts(m, m_dst_ids, m_dst_count); + init_ref_counts(m, m_dst_l_counts); init_overlap_iterator(UINT_MAX, m); for (auto eq : m_eq_occurs) { auto& src = m_eqs[eq]; if (!can_be_subset(monomial(src.l), m)) continue; - if (!is_subset(monomial(src.l))) + if (!is_subset(m_dst_l_counts, m_src_l_counts, monomial(src.l))) continue; m_update_shared_trail.push_back({ idx, s }); push_undo(is_update_shared); diff --git a/src/ast/euf/euf_ac_plugin.h b/src/ast/euf/euf_ac_plugin.h index 0b0b06764..ca8f047ae 100644 --- a/src/ast/euf/euf_ac_plugin.h +++ b/src/ast/euf/euf_ac_plugin.h @@ -75,13 +75,16 @@ namespace euf { uint64_t m_filter = 0; }; + enum eq_status { + processed, to_simplify, is_dead + }; + // represent equalities added by merge_eh and by superposition struct eq { eq(unsigned l, unsigned r, justification j): l(l), r(r), j(j) {} unsigned l, r; // refer to monomials - bool is_processed = false; // true if the equality is in the processed set - bool is_alive = true; + eq_status status = to_simplify; justification j; // justification for equality }; @@ -94,7 +97,7 @@ namespace euf { struct monomial_t { ptr_vector m_nodes; - bloom m_filter; + bloom m_bloom; node* operator[](unsigned i) const { return m_nodes[i]; } unsigned size() const { return m_nodes.size(); } node* const* begin() const { return m_nodes.begin(); } @@ -186,34 +189,50 @@ namespace euf { bool is_sorted(monomial_t const& monomial) const; uint64_t filter(monomial_t& m); bool can_be_subset(monomial_t& subset, monomial_t& superset); - bool subsumes(unsigned src_eq, unsigned dst_eq); + bool are_equal(ptr_vector const& a, ptr_vector const& b); + bool are_equal(monomial_t& a, monomial_t& b); + bool backward_subsumes(unsigned src_eq, unsigned dst_eq); + bool forward_subsumes(unsigned src_eq, unsigned dst_eq); void init_equation(eq const& e); bool orient_equation(eq& e); - void set_processed(unsigned eq_id, bool f); + void set_status(unsigned eq_id, eq_status s); unsigned pick_next_eq(); void forward_simplify(unsigned eq_id, unsigned using_eq); - void backward_simplify(unsigned eq_id, unsigned using_eq); + bool backward_simplify(unsigned eq_id, unsigned using_eq); void superpose(unsigned src_eq, unsigned dst_eq); - ptr_vector m_src_r, m_src_l, m_dst_r; - unsigned_vector m_src_ids, m_src_count, m_dst_ids, m_dst_count; + ptr_vector m_src_r, m_src_l, m_dst_r, m_dst_l; + + struct ref_counts { + unsigned_vector ids; + unsigned_vector counts; + void reset() { for (auto idx : ids) counts[idx] = 0; ids.reset(); } + unsigned operator[](unsigned idx) const { return counts.get(idx, 0); } + void inc(unsigned idx, unsigned amount) { counts.reserve(idx + 1, 0); ids.push_back(idx); counts[idx] += amount; } + unsigned const* begin() const { return ids.begin(); } + unsigned const* end() const { return ids.end(); } + }; + ref_counts m_src_l_counts, m_dst_l_counts, m_src_r_counts, m_dst_r_counts, m_eq_counts; unsigned_vector m_eq_occurs; bool_vector m_eq_seen; - bool m_backward_simplified = false; unsigned_vector const& forward_iterator(unsigned eq); unsigned_vector const& superpose_iterator(unsigned eq); unsigned_vector const& backward_iterator(unsigned eq); - void init_ids_counts(monomial_t const& monomial, unsigned_vector& ids, unsigned_vector& counts); - void reset_ids_counts(unsigned_vector& ids, unsigned_vector& counts); + void init_ref_counts(monomial_t const& monomial, ref_counts& counts); void init_overlap_iterator(unsigned eq, monomial_t const& m); - bool is_subset(monomial_t const& dst); + // check that src is a subset of dst, where dst_counts are precomputed + bool is_subset(ref_counts const& dst_counts, ref_counts& src_counts, monomial_t const& src); + + // check that dst is a superset of dst, where src_counts are precomputed + bool is_superset(ref_counts const& src_counts, ref_counts& dst_counts, monomial_t const& dst); unsigned rewrite(monomial_t const& src_r, monomial_t const& dst_r); - bool is_to_simplify(unsigned eq) const { return !m_eqs[eq].is_processed && m_eqs[eq].is_alive; } - bool is_processed(unsigned eq) const { return m_eqs[eq].is_processed && m_eqs[eq].is_alive; } + bool is_to_simplify(unsigned eq) const { return m_eqs[eq].status == eq_status::to_simplify; } + bool is_processed(unsigned eq) const { return m_eqs[eq].status == eq_status::processed; } + bool is_alive(unsigned eq) const { return m_eqs[eq].status != eq_status::is_dead; } justification justify_rewrite(unsigned eq1, unsigned eq2); justification::dependency* justify_equation(unsigned eq);