3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-22 16:45:31 +00:00

updated AC simplification

This commit is contained in:
Nikolaj Bjorner 2023-11-15 11:01:51 -08:00
parent d5315e2283
commit bf5e6936c0
2 changed files with 215 additions and 140 deletions

View file

@ -63,7 +63,7 @@ TODOs:
namespace euf {
ac_plugin::ac_plugin(egraph& g, unsigned fid, unsigned op):
ac_plugin::ac_plugin(egraph& g, unsigned fid, unsigned op) :
plugin(g), m_fid(fid), m_op(op),
m_dep_manager(get_region()),
m_hash(*this), m_eq(*this), m_monomial_table(m_hash, m_eq)
@ -107,7 +107,7 @@ namespace euf {
n->~node();
break;
}
case is_add_monomial: {
case is_add_monomial: {
m_monomials.pop_back();
break;
}
@ -122,7 +122,7 @@ namespace euf {
break;
}
case is_update_eq: {
auto const & [idx, eq] = m_update_eq_trail.back();
auto const& [idx, eq] = m_update_eq_trail.back();
m_eqs[idx] = eq;
m_update_eq_trail.pop_back();
break;
@ -160,7 +160,7 @@ namespace euf {
display_monomial(out, monomial(e.r));
return out;
}
std::ostream& ac_plugin::display(std::ostream& out) const {
unsigned i = 0;
for (auto const& eq : m_eqs) {
@ -198,10 +198,10 @@ namespace euf {
void ac_plugin::merge_eh(enode* l, enode* r, justification j) {
if (l == r)
return;
if (!is_op(l) && !is_op(r))
if (!is_op(l) && !is_op(r))
merge(mk_node(l), mk_node(r), j);
else
init_equation(eq(to_monomial(l), to_monomial(r), j));
else
init_equation(eq(to_monomial(l), to_monomial(r), j));
}
void ac_plugin::init_equation(eq const& e) {
@ -248,28 +248,30 @@ namespace euf {
}
bool ac_plugin::is_sorted(monomial_t const& m) const {
if (m.m_filter.m_tick == m_tick)
if (m.m_bloom.m_tick == m_tick)
return true;
for (unsigned i = m.size(); i-- > 1; )
if (m[i-1]->root_id() > m[i]->root_id())
return false;
for (unsigned i = m.size(); i-- > 1; )
if (m[i - 1]->root_id() > m[i]->root_id())
return false;
return true;
}
uint64_t ac_plugin::filter(monomial_t& m) {
auto& filter = m.m_filter;
if (filter.m_tick == m_tick)
return filter.m_filter;
filter.m_filter = 0;
auto& bloom = m.m_bloom;
if (bloom.m_tick == m_tick)
return bloom.m_filter;
bloom.m_filter = 0;
for (auto n : m)
filter.m_filter |= (1ull << (n->root_id() % 64ull));
bloom.m_filter |= (1ull << (n->root_id() % 64ull));
if (!is_sorted(m))
sort(m);
filter.m_tick = m_tick;
return filter.m_filter;
bloom.m_tick = m_tick;
return bloom.m_filter;
}
bool ac_plugin::can_be_subset(monomial_t& subset, monomial_t& superset) {
if (subset.size() > superset.size())
return false;
auto f1 = filter(subset);
auto f2 = filter(superset);
return (f1 | f2) == f2;
@ -278,9 +280,9 @@ namespace euf {
void ac_plugin::merge(node* root, node* other, justification j) {
for (auto n : equiv(other))
n->root = root;
m_merge_trail.push_back({ other, root->shared.size(), root->eqs.size()});
m_merge_trail.push_back({ other, root->shared.size(), root->eqs.size() });
for (auto eq_id : other->eqs)
set_processed(eq_id, false);
set_status(eq_id, eq_status::to_simplify);
for (auto m : other->shared)
m_shared_todo.insert(m);
root->shared.append(other->shared);
@ -346,17 +348,17 @@ namespace euf {
void ac_plugin::propagate() {
TRACE("plugin", display(tout));
while (true) {
loop_start:
unsigned eq_id = pick_next_eq();
TRACE("plugin", tout << "propagate " << eq_id << "\n");
if (eq_id == UINT_MAX)
break;
eq& eq = m_eqs[eq_id];
// simplify eq using processed
for (auto other_eq : backward_iterator(eq_id))
if (is_processed(other_eq))
backward_simplify(eq_id, other_eq);
if (m_backward_simplified)
continue;
if (is_processed(other_eq) && backward_simplify(eq_id, other_eq))
goto loop_start;
// simplify processed using eq
for (auto other_eq : forward_iterator(eq_id))
@ -373,14 +375,14 @@ namespace euf {
if (is_to_simplify(other_eq))
forward_simplify(other_eq, eq_id);
set_processed(eq_id, true);
set_status(eq_id, eq_status::processed);
}
propagate_shared();
}
unsigned ac_plugin::pick_next_eq() {
while (!m_to_simplify_todo.empty()) {
unsigned id = *m_to_simplify_todo.begin();
unsigned id = *m_to_simplify_todo.begin();
if (id < m_eqs.size() && is_to_simplify(id))
return id;
m_to_simplify_todo.remove(id);
@ -388,17 +390,28 @@ namespace euf {
return UINT_MAX;
}
void ac_plugin::set_processed(unsigned id, bool f) {
void ac_plugin::set_status(unsigned id, eq_status s) {
auto& eq = m_eqs[id];
if (eq.is_processed == f)
if (eq.status == eq_status::is_dead)
return;
if (f)
if (s == eq_status::to_simplify && are_equal(monomial(eq.l), monomial(eq.r)))
s = eq_status::is_dead;
if (eq.status != s) {
m_update_eq_trail.push_back({ id, eq });
eq.status = s;
push_undo(is_update_eq);
}
switch (s) {
case eq_status::processed:
case eq_status::is_dead:
m_to_simplify_todo.remove(id);
else
break;
case eq_status::to_simplify:
m_to_simplify_todo.insert(id);
m_update_eq_trail.push_back({ id, eq });
eq.is_processed = f;
push_undo(is_update_eq);
orient_equation(eq);
break;
}
}
//
@ -408,7 +421,7 @@ namespace euf {
auto const& eq = m_eqs[eq_id];
m_src_r.reset();
m_src_r.append(monomial(eq.r).m_nodes);
init_ids_counts(monomial(eq.l), m_src_ids, m_src_count);
init_ref_counts(monomial(eq.l), m_src_l_counts);
init_overlap_iterator(eq_id, monomial(eq.l));
return m_eq_occurs;
}
@ -419,9 +432,9 @@ namespace euf {
//
unsigned_vector const& ac_plugin::backward_iterator(unsigned eq_id) {
auto const& eq = m_eqs[eq_id];
init_ids_counts(monomial(eq.r), m_dst_ids, m_dst_count);
init_ref_counts(monomial(eq.r), m_dst_r_counts);
init_ref_counts(monomial(eq.l), m_dst_l_counts);
init_overlap_iterator(eq_id, monomial(eq.r));
m_backward_simplified = false;
return m_eq_occurs;
}
@ -454,7 +467,8 @@ namespace euf {
auto& eq = m_eqs[eq_id];
m_src_r.reset();
m_src_r.append(monomial(eq.r).m_nodes);
init_ids_counts(monomial(eq.l), m_src_ids, m_src_count);
init_ref_counts(monomial(eq.l), m_src_l_counts);
init_ref_counts(monomial(eq.r), m_src_r_counts);
unsigned min_r = UINT_MAX;
node* min_n = nullptr;
for (auto n : monomial(eq.l))
@ -465,22 +479,13 @@ namespace euf {
return min_n->eqs;
}
void ac_plugin::init_ids_counts(monomial_t const& monomial, unsigned_vector& ids, unsigned_vector& counts) {
reset_ids_counts(ids, counts);
for (auto n : monomial) {
unsigned id = n->root_id();
counts.setx(id, counts.get(id, 0) + 1, 0);
ids.push_back(id);
}
// Rebuild 'counts' from scratch so that counts[id] is the number of
// nodes in 'monomial' whose root id is 'id'.
void ac_plugin::init_ref_counts(monomial_t const& monomial, ref_counts& counts) {
    counts.reset();
    for (node* arg : monomial) {
        unsigned const id = arg->root_id();
        counts.inc(id, 1);
    }
}
void ac_plugin::reset_ids_counts(unsigned_vector& ids, unsigned_vector& counts) {
for (auto id : ids)
counts[id] = 0;
ids.reset();
}
void ac_plugin::forward_simplify(unsigned dst_eq, unsigned src_eq) {
void ac_plugin::forward_simplify(unsigned dst_eq, unsigned src_eq) {
if (src_eq == dst_eq)
return;
@ -495,7 +500,12 @@ namespace euf {
if (!can_be_subset(monomial(src.l), monomial(dst.r)))
return;
reset_ids_counts(m_dst_ids, m_dst_count);
if (forward_subsumes(src_eq, dst_eq)) {
set_status(dst_eq, eq_status::is_dead);
return;
}
m_dst_r_counts.reset();
unsigned src_l_size = monomial(src.l).size();
unsigned src_r_size = m_src_r.size();
@ -505,16 +515,13 @@ namespace euf {
unsigned num_overlap = 0;
for (auto n : monomial(dst.r)) {
unsigned id = n->root_id();
unsigned count = m_src_count.get(id, 0);
if (count == 0)
unsigned count = m_src_l_counts[id];
if (count == 0)
m_src_r.push_back(n);
else {
unsigned dst_count = m_dst_count.get(id, 0);
if (dst_count >= count)
m_src_r.push_back(n);
else
m_dst_count.set(id, dst_count + 1), m_dst_ids.push_back(id), ++num_overlap;
}
else if (m_dst_r_counts[id] >= count)
m_src_r.push_back(n);
else
m_dst_r_counts.inc(id, 1), ++num_overlap;
}
// The dst.r has to be a superset of src.l, otherwise simplification does not apply
if (num_overlap == src_l_size) {
@ -527,9 +534,9 @@ namespace euf {
m_src_r.shrink(src_r_size);
}
void ac_plugin::backward_simplify(unsigned dst_eq, unsigned src_eq) {
bool ac_plugin::backward_simplify(unsigned dst_eq, unsigned src_eq) {
if (src_eq == dst_eq)
return;
return false;
auto& src = m_eqs[src_eq];
auto& dst = m_eqs[dst_eq];
@ -539,30 +546,80 @@ namespace euf {
// check that src.l is a subset of dst.r
if (!can_be_subset(monomial(src.l), monomial(dst.r)))
return;
if (!is_subset(monomial(src.l)))
return;
return false;
if (!is_subset(m_dst_l_counts, m_src_l_counts, monomial(src.l)))
return false;
if (backward_subsumes(src_eq, dst_eq)) {
set_status(dst_eq, eq_status::is_dead);
return true;
}
// dst_rhs := dst_rhs - src_lhs + src_rhs
auto new_r = rewrite(monomial(src.r), monomial(dst.r));
m_update_eq_trail.push_back({ dst_eq, m_eqs[dst_eq] });
m_eqs[dst_eq].r = new_r;
m_eqs[dst_eq].j = justify_rewrite(src_eq, dst_eq);
push_undo(is_update_eq);
m_backward_simplified = true;
return true;
}
bool ac_plugin::subsumes(unsigned src_eq, unsigned dst_eq) {
// Is dst_eq subsumed by src_eq? That is, is there a multiset X of root ids
// such that dst.l = src.l + X and dst.r = src.r + X.
//
// dst_eq is fixed by the caller: m_dst_l_counts is pre-computed for
// monomial(dst.l) and m_dst_r_counts for monomial(dst.r) (see backward_iterator).
// These two tables are shared by every src_eq tried against the same dst_eq,
// so they are only read here. m_src_l_counts and m_src_r_counts are scratch:
// on success of the is_subset calls they hold the counts of src.l and src.r.
bool ac_plugin::backward_subsumes(unsigned src_eq, unsigned dst_eq) {
    auto& src = m_eqs[src_eq];
    auto& dst = m_eqs[dst_eq];
    // cheap size / bloom-filter rejections; they also guarantee that the
    // unsigned size subtractions below cannot wrap.
    if (!can_be_subset(monomial(src.l), monomial(dst.l)))
        return false;
    if (!can_be_subset(monomial(src.r), monomial(dst.r)))
        return false;
    unsigned size_diff = monomial(dst.l).size() - monomial(src.l).size();
    if (size_diff != monomial(dst.r).size() - monomial(src.r).size())
        return false;
    // exact multiset inclusion: src.l <= dst.l and src.r <= dst.r
    if (!is_subset(m_dst_l_counts, m_src_l_counts, monomial(src.l)))
        return false;
    if (!is_subset(m_dst_r_counts, m_src_r_counts, monomial(src.r)))
        return false;
    // The surplus of dst over src has to be the same multiset on both sides.
    // Both inclusions hold, so every difference below is non-negative.
    // Checking the ids of dst.r is sufficient: the total surplus has the same
    // size on both sides (size_diff), so if the surpluses agree on all ids of
    // dst.r then dst.l cannot have surplus on any other id.
    return all_of(monomial(dst.r), [&](node* n) {
        unsigned id = n->root_id();
        unsigned surplus_r = m_dst_r_counts[id] - m_src_r_counts[id];
        unsigned surplus_l = m_dst_l_counts[id] - m_src_l_counts[id];
        return surplus_r == surplus_l;
    });
}
NOT_IMPLEMENTED_YET();
// dst.l \ src.l = dst.r \ dst.r
return false;
// src_counts, src2_counts are initialized for src_eq
bool ac_plugin::forward_subsumes(unsigned src_eq, unsigned dst_eq) {
auto& src = m_eqs[src_eq];
auto& dst = m_eqs[dst_eq];
if (!can_be_subset(monomial(src.l), monomial(dst.l)))
return false;
if (!can_be_subset(monomial(src.r), monomial(dst.r)))
return false;
unsigned size_diff = monomial(dst.l).size() - monomial(src.l).size();
if (size_diff != monomial(dst.r).size() - monomial(src.r).size())
return false;
if (!is_superset(m_src_l_counts, m_dst_l_counts, monomial(dst.l)))
return false;
if (!is_subset(m_src_r_counts, m_dst_r_counts, monomial(dst.r)))
return false;
for (auto n : monomial(src.l)) {
unsigned id = n->root_id();
SASSERT(m_src_l_counts[id] >= m_dst_l_counts[id]);
unsigned diff = m_src_l_counts[id] - m_dst_l_counts[id];
if (diff > 0) {
m_dst_l_counts.inc(id, diff);
m_dst_r_counts.inc(id, diff);
}
}
return all_of(monomial(dst.r), [&](node* n) { unsigned id = n->root_id(); return m_src_r_counts[id] == m_dst_r_counts[id]; });
}
unsigned ac_plugin::rewrite(monomial_t const& src_r, monomial_t const& dst_r) {
@ -573,33 +630,39 @@ namespace euf {
// add to m_src_r elements of dst.r that are not in src.l
for (auto n : dst_r) {
unsigned id = n->root_id();
unsigned count = m_src_count.get(id, 0);
unsigned count = m_src_l_counts[id];
if (count == 0)
m_src_r.push_back(n);
else
--m_src_count[id];
m_src_l_counts.inc(id, -1);
}
return to_monomial(nullptr, m_src_r);
}
bool ac_plugin::is_subset(monomial_t const& dst) {
reset_ids_counts(m_src_ids, m_src_count);
for (auto n : dst) {
// check that src is a subset of dst, where dst_counts are precomputed
bool ac_plugin::is_subset(ref_counts const& dst_counts, ref_counts& src_counts, monomial_t const& src) {
SASSERT(&dst_counts != &src_counts);
src_counts.reset();
for (auto n : src) {
unsigned id = n->root_id();
unsigned dst_count = m_dst_count.get(id, 0);
unsigned dst_count = dst_counts[id];
if (dst_count == 0)
return false;
else {
unsigned src_count = m_src_count.get(id, 0);
if (src_count >= dst_count)
return false;
else
m_src_count.set(id, src_count + 1), m_src_ids.push_back(id);
}
else if (src_counts[id] >= dst_count)
return false;
else
src_counts.inc(id, 1);
}
return true;
}
// Check that dst is a superset of src (as multisets of root ids), where
// src_counts are precomputed for src.
// dst_counts is scratch: it is overwritten with the counts of dst.
// Only the ids recorded in src_counts need to be inspected; an id that does
// not occur in src imposes no constraint on dst.
bool ac_plugin::is_superset(ref_counts const& src_counts, ref_counts& dst_counts, monomial_t const& dst) {
    SASSERT(&dst_counts != &src_counts);
    init_ref_counts(dst, dst_counts);
    // dst must contain each id at least as often as src does
    return all_of(src_counts, [&](unsigned idx) { return src_counts[idx] <= dst_counts[idx]; });
}
void ac_plugin::superpose(unsigned src_eq, unsigned dst_eq) {
if (src_eq == dst_eq)
return;
@ -609,7 +672,7 @@ namespace euf {
TRACE("plugin", tout << "superpose: "; display_equation(tout, src); tout << " "; display_equation(tout, dst); tout << "\n";);
// AB -> C, AD -> E => BE ~ CD
// m_src_l_counts contains information about src (call it AD -> E)
reset_ids_counts(m_dst_ids, m_dst_count);
m_dst_l_counts.reset();
m_dst_r.reset();
m_dst_r.append(monomial(dst.r).m_nodes);
@ -622,63 +685,56 @@ namespace euf {
// compute BE, initialize m_dst_l_counts
for (auto n : monomial(dst.l)) {
unsigned id = n->root_id();
unsigned src_count = m_src_count.get(id, 0);
unsigned dst_count = m_dst_count.get(id, 0);
m_dst_count.set(id, dst_count + 1);
m_dst_ids.push_back(id);
if (src_count < dst_count)
m_src_r.push_back(n);
if (m_src_l_counts[id] < m_dst_l_counts[id])
m_src_r.push_back(n);
m_dst_l_counts.inc(id, 1);
}
// compute CD
for (auto n : monomial(src.l)) {
unsigned id = n->root_id();
unsigned dst_count = m_dst_count.get(id, 0);
if (dst_count > 0)
--m_dst_count[id];
if (m_dst_l_counts[id] > 0)
m_dst_l_counts.inc(id, -1);
else
m_dst_r.push_back(n);
}
// one side is a proper subset of the other
if (m_src_r.size() == src_r_size || m_dst_r.size() == dst_r_size) {
if (are_equal(m_src_r, m_dst_r)) {
m_src_r.shrink(src_r_size);
return;
}
if (m_src_r.size() == m_dst_r.size()) {
reset_ids_counts(m_dst_ids, m_dst_count);
bool are_equal = true;
for (auto n : m_dst_r) {
unsigned id = n->root_id();
unsigned dst_count = m_dst_count.get(id, 0);
m_dst_count.set(id, dst_count + 1);
}
for (auto n : m_src_r) {
unsigned id = n->root_id();
unsigned dst_count = m_dst_count.get(id, 0);
if (dst_count > 0)
m_dst_count[id]--;
else {
are_equal = false;
break;
}
}
if (are_equal) {
m_src_r.shrink(src_r_size);
return;
}
}
TRACE("plugin", for (auto n : m_src_r) tout << g.bpp(n->n) << " "; tout << "== "; for (auto n : m_dst_r) tout << g.bpp(n->n) << " "; tout << "\n";);
justification j = justify_rewrite(src_eq, dst_eq);
if (m_src_r.size() == 1 && m_dst_r.size() == 1)
push_merge(m_src_r[0]->n, m_dst_r[0]->n, j);
else
if (m_src_r.size() == 1 && m_dst_r.size() == 1)
push_merge(m_src_r[0]->n, m_dst_r[0]->n, j);
else
init_equation(eq(to_monomial(nullptr, m_src_r), to_monomial(nullptr, m_dst_r), j));
m_src_r.shrink(src_r_size);
}
// Monomial equality: the bloom filters are compared first as a cheap
// rejection test, and only when they agree are the node vectors compared.
bool ac_plugin::are_equal(monomial_t& a, monomial_t& b) {
    if (filter(a) != filter(b))
        return false;
    return are_equal(a.m_nodes, b.m_nodes);
}
// Multiset equality of two node vectors, comparing nodes by root id.
// Uses m_eq_counts as scratch: count the ids of 'a', then consume one
// occurrence per element of 'b'. Since both vectors have the same length,
// 'b' running out of matching occurrences is the only way to differ.
bool ac_plugin::are_equal(ptr_vector<node> const& a, ptr_vector<node> const& b) {
    if (a.size() != b.size())
        return false;
    m_eq_counts.reset();
    for (node* x : a)
        m_eq_counts.inc(x->root_id(), 1);
    for (node* y : b) {
        unsigned const rid = y->root_id();
        if (m_eq_counts[rid] == 0)
            return false;
        // unsigned wrap-around: adding -1 decrements the count by one
        m_eq_counts.inc(rid, -1);
    }
    return true;
}
//
// simple version based on propagating all shared
// todo: version touching only newly processed shared, and maintaining incremental data-structures.
@ -712,13 +768,13 @@ namespace euf {
while (change) {
change = false;
auto & m = monomial(s.m);
init_ids_counts(m, m_dst_ids, m_dst_count);
init_ref_counts(m, m_dst_l_counts);
init_overlap_iterator(UINT_MAX, m);
for (auto eq : m_eq_occurs) {
auto& src = m_eqs[eq];
if (!can_be_subset(monomial(src.l), m))
continue;
if (!is_subset(monomial(src.l)))
if (!is_subset(m_dst_l_counts, m_src_l_counts, monomial(src.l)))
continue;
m_update_shared_trail.push_back({ idx, s });
push_undo(is_update_shared);

View file

@ -75,13 +75,16 @@ namespace euf {
uint64_t m_filter = 0;
};
// Status of an equation. Kept as an unscoped enum: the enumerators are
// also referenced without the eq_status:: qualifier.
enum eq_status {
    processed,    // equation is in the processed set
    to_simplify,  // equation is queued in the to-simplify set
    is_dead       // equation is retired; set_status ignores further updates
};
// represent equalities added by merge_eh and by superposition
struct eq {
eq(unsigned l, unsigned r, justification j):
l(l), r(r), j(j) {}
unsigned l, r; // refer to monomials
bool is_processed = false; // true if the equality is in the processed set
bool is_alive = true;
eq_status status = to_simplify;
justification j; // justification for equality
};
@ -94,7 +97,7 @@ namespace euf {
struct monomial_t {
ptr_vector<node> m_nodes;
bloom m_filter;
bloom m_bloom;
node* operator[](unsigned i) const { return m_nodes[i]; }
unsigned size() const { return m_nodes.size(); }
node* const* begin() const { return m_nodes.begin(); }
@ -186,34 +189,50 @@ namespace euf {
bool is_sorted(monomial_t const& monomial) const;
uint64_t filter(monomial_t& m);
bool can_be_subset(monomial_t& subset, monomial_t& superset);
bool subsumes(unsigned src_eq, unsigned dst_eq);
bool are_equal(ptr_vector<node> const& a, ptr_vector<node> const& b);
bool are_equal(monomial_t& a, monomial_t& b);
bool backward_subsumes(unsigned src_eq, unsigned dst_eq);
bool forward_subsumes(unsigned src_eq, unsigned dst_eq);
void init_equation(eq const& e);
bool orient_equation(eq& e);
void set_processed(unsigned eq_id, bool f);
void set_status(unsigned eq_id, eq_status s);
unsigned pick_next_eq();
void forward_simplify(unsigned eq_id, unsigned using_eq);
void backward_simplify(unsigned eq_id, unsigned using_eq);
bool backward_simplify(unsigned eq_id, unsigned using_eq);
void superpose(unsigned src_eq, unsigned dst_eq);
ptr_vector<node> m_src_r, m_src_l, m_dst_r;
unsigned_vector m_src_ids, m_src_count, m_dst_ids, m_dst_count;
ptr_vector<node> m_src_r, m_src_l, m_dst_r, m_dst_l;
// Sparse counter table indexed by unsigned ids.
// 'counts' holds the current count per id; 'ids' logs every id passed to
// inc since the last reset (duplicates possible), so that reset only has
// to clear the entries that were touched.
struct ref_counts {
    unsigned_vector ids;
    unsigned_vector counts;

    // zero every touched entry and forget the log
    void reset() {
        for (auto touched : ids)
            counts[touched] = 0;
        ids.reset();
    }

    // current count of idx; ids beyond the table read as 0
    unsigned operator[](unsigned idx) const {
        return counts.get(idx, 0);
    }

    // add 'amount' to the count of idx, growing the table on demand.
    // Callers also pass -1: the unsigned addition wraps around, which
    // decrements the count by one.
    void inc(unsigned idx, unsigned amount) {
        counts.reserve(idx + 1, 0);
        ids.push_back(idx);
        counts[idx] += amount;
    }

    // iteration ranges over the logged ids
    unsigned const* begin() const { return ids.begin(); }
    unsigned const* end() const { return ids.end(); }
};
ref_counts m_src_l_counts, m_dst_l_counts, m_src_r_counts, m_dst_r_counts, m_eq_counts;
unsigned_vector m_eq_occurs;
bool_vector m_eq_seen;
bool m_backward_simplified = false;
unsigned_vector const& forward_iterator(unsigned eq);
unsigned_vector const& superpose_iterator(unsigned eq);
unsigned_vector const& backward_iterator(unsigned eq);
void init_ids_counts(monomial_t const& monomial, unsigned_vector& ids, unsigned_vector& counts);
void reset_ids_counts(unsigned_vector& ids, unsigned_vector& counts);
void init_ref_counts(monomial_t const& monomial, ref_counts& counts);
void init_overlap_iterator(unsigned eq, monomial_t const& m);
bool is_subset(monomial_t const& dst);
// check that src is a subset of dst, where dst_counts are precomputed
bool is_subset(ref_counts const& dst_counts, ref_counts& src_counts, monomial_t const& src);
// check that dst is a superset of src, where src_counts are precomputed
bool is_superset(ref_counts const& src_counts, ref_counts& dst_counts, monomial_t const& dst);
unsigned rewrite(monomial_t const& src_r, monomial_t const& dst_r);
bool is_to_simplify(unsigned eq) const { return !m_eqs[eq].is_processed && m_eqs[eq].is_alive; }
bool is_processed(unsigned eq) const { return m_eqs[eq].is_processed && m_eqs[eq].is_alive; }
bool is_to_simplify(unsigned eq) const { return m_eqs[eq].status == eq_status::to_simplify; }
bool is_processed(unsigned eq) const { return m_eqs[eq].status == eq_status::processed; }
bool is_alive(unsigned eq) const { return m_eqs[eq].status != eq_status::is_dead; }
justification justify_rewrite(unsigned eq1, unsigned eq2);
justification::dependency* justify_equation(unsigned eq);