3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-07-29 15:37:58 +00:00

revamp ac plugin and plugin propagation

This commit is contained in:
Nikolaj Bjorner 2025-07-21 07:35:06 -07:00
parent b983524afc
commit dbcbc6c3ac
14 changed files with 630 additions and 215 deletions

View file

@ -20,11 +20,38 @@ Completion modulo AC
Add new equation zu = xyu = vy by j1, j2 Add new equation zu = xyu = vy by j1, j2
Notes: Sets P - processed, R - reductions, S - to simplify
- Some equalities come from shared terms, some do not.
new equality l = r:
reduce l = r modulo R if equation is external
orient l = r - if it cannot be oriented, discard
if l = r is a reduction rule then reduce R, S using l = r, insert into R
else insert into S
main loop:
for e as (l = r) in S:
remove e from S
backward simplify e
if e is backward subsumed, continue
if e is a reduction rule, then reduce R, S using e, insert into R, continue
insert e into P
superpose with e
forward simplify with e
backward simplify e as (l = r) using (l' = r') in P u S:
if l' is a subset or r then replace l' by r' in r.
backward subsumption e as (l = r) using (l' = r') in P u S:
l = r is of the form l'x = r'x
is reduction rule e as (l = r):
l is a unit, and r is unit, is empty, or is zero.
- V2 can use multiplicities of elements to handle larger domains. superpose e as (l = r) with (l' = r') in P:
- e.g. 3x + 100000y if l and l' share a common subset x.
forward simplify (l' = r') in P u S using e as (l = r):
More notes: More notes:
@ -107,6 +134,19 @@ namespace euf {
register_shared(arg); register_shared(arg);
} }
// unit -> {}
void ac_plugin::add_unit(enode* u) {
m_units.push_back(u);
auto n = mk_node(u);
auto m_id = to_monomial(u, {});
init_equation(eq(to_monomial(u), m_id, justification::axiom(get_id())));
}
// zero x -> zero
void ac_plugin::add_zero(enode* z) {
mk_node(z)->is_zero = true;
}
void ac_plugin::register_shared(enode* n) { void ac_plugin::register_shared(enode* n) {
if (m_shared_nodes.get(n->get_id(), false)) if (m_shared_nodes.get(n->get_id(), false))
return; return;
@ -144,16 +184,6 @@ namespace euf {
m_monomials.pop_back(); m_monomials.pop_back();
break; break;
} }
case is_merge_node: {
auto [other, old_shared, old_eqs] = m_merge_trail.back();
auto* root = other->root;
std::swap(other->next, root->next);
root->shared.shrink(old_shared);
root->eqs.shrink(old_eqs);
m_merge_trail.pop_back();
++m_tick;
break;
}
case is_update_eq: { case is_update_eq: {
auto const& [idx, eq] = m_update_eq_trail.back(); auto const& [idx, eq] = m_update_eq_trail.back();
m_eqs[idx] = eq; m_eqs[idx] = eq;
@ -226,6 +256,7 @@ namespace euf {
case eq_status::is_dead: out << "d"; break; case eq_status::is_dead: out << "d"; break;
case eq_status::processed: out << "p"; break; case eq_status::processed: out << "p"; break;
case eq_status::to_simplify: out << "s"; break; case eq_status::to_simplify: out << "s"; break;
case eq_status::is_reducing_eq: out << "r"; break;
} }
return out; return out;
} }
@ -234,15 +265,16 @@ namespace euf {
out << m_name << "\n"; out << m_name << "\n";
unsigned i = 0; unsigned i = 0;
for (auto const& eq : m_eqs) { for (auto const& eq : m_eqs) {
if (eq.status != eq_status::is_dead) if (eq.status != eq_status::is_dead) {
out << i << ": " << eq_pp_ll(*this, eq) << "\n"; out << "["; display_status(out, eq.status) << "] " << i << " : " << eq_pp_ll(*this, eq) << "\n";
}
++i; ++i;
} }
if (!m_shared.empty()) if (!m_shared.empty())
out << "shared monomials:\n"; out << "shared monomials:\n";
for (auto const& s : m_shared) { for (auto const& s : m_shared) {
out << g.bpp(s.n) << ": " << s.m << "\n"; out << g.bpp(s.n) << ": " << s.m << " r: " << g.bpp(s.n->get_root()) << "\n";
} }
#if 0 #if 0
i = 0; i = 0;
@ -274,13 +306,21 @@ namespace euf {
return out; return out;
} }
void ac_plugin::collect_statistics(statistics& st) const {
std::string name = m_name.str();
m_superposition_stats = symbol((std::string("ac ") + name + " superpositions"));
m_eqs_stats = symbol((std::string("ac ") + name + " equations"));
st.update(m_superposition_stats.bare_str(), m_stats.m_num_superpositions);
st.update(m_eqs_stats.bare_str(), m_eqs.size());
}
void ac_plugin::merge_eh(enode* l, enode* r) { void ac_plugin::merge_eh(enode* l, enode* r) {
if (l == r) if (l == r)
return; return;
m_fuel += m_fuel_inc;
auto j = justification::equality(l, r); auto j = justification::equality(l, r);
TRACE(plugin, tout << "merge: " << m_name << " " << g.bpp(l) << " == " << g.bpp(r) << " " << is_op(l) << " " << is_op(r) << "\n"); TRACE(plugin, tout << "merge: " << m_name << " " << g.bpp(l) << " == " << g.bpp(r) << " " << is_op(l) << " " << is_op(r) << "\n");
if (!is_op(l) && !is_op(r))
merge(mk_node(l), mk_node(r), j);
init_equation(eq(to_monomial(l), to_monomial(r), j)); init_equation(eq(to_monomial(l), to_monomial(r), j));
} }
@ -294,7 +334,7 @@ namespace euf {
register_shared(b); register_shared(b);
} }
void ac_plugin::init_equation(eq const& e) { bool ac_plugin::init_equation(eq const& e) {
m_eqs.push_back(e); m_eqs.push_back(e);
auto& eq = m_eqs.back(); auto& eq = m_eqs.back();
deduplicate(monomial(eq.l).m_nodes, monomial(eq.r).m_nodes); deduplicate(monomial(eq.l).m_nodes, monomial(eq.r).m_nodes);
@ -303,44 +343,53 @@ namespace euf {
auto& ml = monomial(eq.l); auto& ml = monomial(eq.l);
auto& mr = monomial(eq.r); auto& mr = monomial(eq.r);
if (ml.size() == 1 && mr.size() == 1)
push_merge(ml[0]->n, mr[0]->n, eq.j);
unsigned eq_id = m_eqs.size() - 1; unsigned eq_id = m_eqs.size() - 1;
if (ml.size() == 1 && mr.size() == 1)
push_merge(ml[0]->n, mr[0]->n, eq.j);
for (auto n : ml) { for (auto n : ml) {
if (!n->root->n->is_marked1()) { if (!n->n->is_marked2()) {
n->root->eqs.push_back(eq_id); n->eqs.push_back(eq_id);
n->root->n->mark1(); n->n->mark2();
push_undo(is_add_eq_index); push_undo(is_add_eq_index);
m_node_trail.push_back(n->root); m_node_trail.push_back(n);
for (auto s : n->root->shared) for (auto s : n->shared)
m_shared_todo.insert(s); m_shared_todo.insert(s);
} }
} }
for (auto n : mr) { for (auto n : mr) {
if (!n->root->n->is_marked1()) { if (!n->n->is_marked2()) {
n->root->eqs.push_back(eq_id); n->eqs.push_back(eq_id);
n->root->n->mark1(); n->n->mark2();
push_undo(is_add_eq_index); push_undo(is_add_eq_index);
m_node_trail.push_back(n->root); m_node_trail.push_back(n);
for (auto s : n->root->shared) for (auto s : n->shared)
m_shared_todo.insert(s); m_shared_todo.insert(s);
} }
} }
for (auto n : ml) for (auto n : ml)
n->root->n->unmark1(); n->n->unmark2();
for (auto n : mr) for (auto n : mr)
n->root->n->unmark1(); n->n->unmark2();
TRACE(plugin, display_equation_ll(tout, e) << " shared: " << m_shared_todo << "\n"); SASSERT(well_formed(eq));
TRACE(plugin, display_equation_ll(tout, eq) << " shared: " << m_shared_todo << "\n");
m_to_simplify_todo.insert(eq_id); m_to_simplify_todo.insert(eq_id);
m_new_eqs.push_back(eq_id);
//display_equation_ll(verbose_stream() << "init " << eq_id << ": ", eq) << "\n";
return true;
} }
else else {
m_eqs.pop_back(); m_eqs.pop_back();
return false;
}
} }
bool ac_plugin::orient_equation(eq& e) { bool ac_plugin::orient_equation(eq& e) {
@ -361,7 +410,7 @@ namespace euf {
if (ml[i]->id() < mr[i]->id()) if (ml[i]->id() < mr[i]->id())
std::swap(e.l, e.r); std::swap(e.l, e.r);
return true; return true;
} }
return false; return false;
} }
} }
@ -436,25 +485,6 @@ namespace euf {
return (filter(subset) | f2) == f2; return (filter(subset) | f2) == f2;
} }
void ac_plugin::merge(node* a, node* b, justification j) {
TRACE(plugin, tout << g.bpp(a->n) << " == " << g.bpp(b->n) << " num shared " << b->shared.size() << "\n");
if (a == b)
return;
if (a->id() < b->id())
std::swap(a, b);
for (auto n : equiv(a))
n->root = b;
m_merge_trail.push_back({ a, b->shared.size(), b->eqs.size() });
for (auto eq_id : a->eqs)
set_status(eq_id, eq_status::to_simplify);
for (auto m : a->shared)
m_shared_todo.insert(m);
b->shared.append(a->shared);
b->eqs.append(a->eqs);
std::swap(b->next, a->next);
push_undo(is_merge_node);
++m_tick;
}
void ac_plugin::push_undo(undo_kind k) { void ac_plugin::push_undo(undo_kind k) {
m_undo.push_back(k); m_undo.push_back(k);
@ -489,10 +519,21 @@ namespace euf {
ptr_buffer<expr> args; ptr_buffer<expr> args;
enode_vector nodes; enode_vector nodes;
for (auto arg : mon) { for (auto arg : mon) {
nodes.push_back(arg->root->n); nodes.push_back(arg->n);
args.push_back(arg->root->n->get_expr()); args.push_back(arg->n->get_expr());
}
expr* n = nullptr;
switch (args.size()) {
case 0:
UNREACHABLE();
break;
case 1:
n = args[0];
break;
default:
n = m.mk_app(m_fid, m_op, args.size(), args.data());
break;
} }
auto n = args.size() == 1 ? args[0] : m.mk_app(m_fid, m_op, args.size(), args.data());
auto r = g.find(n); auto r = g.find(n);
return r ? r : g.mk(n, 0, nodes.size(), nodes.data()); return r ? r : g.mk(n, 0, nodes.size(), nodes.data());
} }
@ -501,8 +542,6 @@ namespace euf {
auto* mem = r.allocate(sizeof(node)); auto* mem = r.allocate(sizeof(node));
node* res = new (mem) node(); node* res = new (mem) node();
res->n = n; res->n = n;
res->root = res;
res->next = res;
return res; return res;
} }
@ -521,13 +560,22 @@ namespace euf {
} }
void ac_plugin::propagate() { void ac_plugin::propagate() {
//verbose_stream() << "propagate " << m_name << "\n";
unsigned ts = m_to_simplify_todo.size();
unsigned round = 0;
while (true) { while (true) {
loop_start: loop_start:
//verbose_stream() << "loop_start " << (round++) << " " << m_to_simplify_todo.size() << " ts: " << ts << "\n";
if (m_fuel == 0)
break;
unsigned eq_id = pick_next_eq(); unsigned eq_id = pick_next_eq();
if (eq_id == UINT_MAX) if (eq_id == UINT_MAX)
break; break;
TRACE(plugin, tout << "propagate " << eq_id << ": " << eq_pp_ll(*this, m_eqs[eq_id]) << "\n"); TRACE(plugin, tout << "propagate " << eq_id << ": " << eq_pp_ll(*this, m_eqs[eq_id]) << "\n");
//verbose_stream() << m_name << " propagate eq " << eq_id << ": " << eq_pp_ll(*this, m_eqs[eq_id]) << "\n";
SASSERT(well_formed(m_eqs[eq_id]));
// simplify eq using processed // simplify eq using processed
TRACE(plugin, TRACE(plugin,
@ -543,26 +591,41 @@ namespace euf {
set_status(eq_id, eq_status::is_dead); set_status(eq_id, eq_status::is_dead);
continue; continue;
} }
if (is_backward_subsumed(eq_id)) {
set_status(eq_id, eq_status::is_dead);
continue;
}
if (is_reducing(eq)) {
set_status(eq_id, eq_status::is_reducing_eq);
forward_reduce(eq_id);
continue;
}
--m_fuel;
set_status(eq_id, eq_status::processed); set_status(eq_id, eq_status::processed);
// simplify processed using eq // simplify processed using eq
for (auto other_eq : forward_iterator(eq_id)) for (auto other_eq : forward_iterator(eq_id))
if (is_processed(other_eq)) if (is_processed(other_eq) || is_reducing(other_eq))
forward_simplify(eq_id, other_eq); forward_simplify(eq_id, other_eq);
backward_subsume_new_eqs();
// superpose, create new equations // superpose, create new equations
unsigned new_eqs = 0; unsigned new_sup = 0;
m_new_eqs.reset();
for (auto other_eq : superpose_iterator(eq_id)) for (auto other_eq : superpose_iterator(eq_id))
if (is_processed(other_eq)) if (is_processed(other_eq))
new_eqs += superpose(eq_id, other_eq); new_sup += superpose(eq_id, other_eq);
backward_subsume_new_eqs();
(void)new_eqs; m_stats.m_num_superpositions += new_sup;
TRACE(plugin, tout << "added eqs " << new_eqs << "\n"); TRACE(plugin, tout << "new superpositions " << new_sup << "\n");
// simplify to_simplify using eq // simplify to_simplify using eq
for (auto other_eq : forward_iterator(eq_id)) for (auto other_eq : forward_iterator(eq_id))
if (is_to_simplify(other_eq)) if (is_to_simplify(other_eq))
forward_simplify(eq_id, other_eq); forward_simplify(eq_id, other_eq);
backward_subsume_new_eqs();
} }
propagate_shared(); propagate_shared();
@ -584,7 +647,7 @@ namespace euf {
auto& eq = m_eqs[id]; auto& eq = m_eqs[id];
if (eq.status == eq_status::is_dead) if (eq.status == eq_status::is_dead)
return; return;
if (s == eq_status::to_simplify && are_equal(monomial(eq.l), monomial(eq.r))) if (are_equal(monomial(eq.l), monomial(eq.r)))
s = eq_status::is_dead; s = eq_status::is_dead;
if (eq.status != s) { if (eq.status != s) {
@ -594,12 +657,15 @@ namespace euf {
} }
switch (s) { switch (s) {
case eq_status::processed: case eq_status::processed:
case eq_status::is_reducing_eq:
case eq_status::is_dead: case eq_status::is_dead:
m_to_simplify_todo.remove(id); m_to_simplify_todo.remove(id);
break; break;
case eq_status::to_simplify: case eq_status::to_simplify:
m_to_simplify_todo.insert(id); m_to_simplify_todo.insert(id);
orient_equation(eq); if (!orient_equation(eq)) {
set_status(id, eq_status::is_dead);
}
break; break;
} }
} }
@ -624,6 +690,12 @@ namespace euf {
auto const& eq = m_eqs[eq_id]; auto const& eq = m_eqs[eq_id];
init_ref_counts(monomial(eq.r), m_dst_r_counts); init_ref_counts(monomial(eq.r), m_dst_r_counts);
init_ref_counts(monomial(eq.l), m_dst_l_counts); init_ref_counts(monomial(eq.l), m_dst_l_counts);
if (monomial(eq.r).size() == 0) {
m_dst_l.reset();
m_dst_l.append(monomial(eq.l).m_nodes);
init_subset_iterator(eq_id, monomial(eq.l));
return m_eq_occurs;
}
m_dst_r.reset(); m_dst_r.reset();
m_dst_r.append(monomial(eq.r).m_nodes); m_dst_r.append(monomial(eq.r).m_nodes);
init_subset_iterator(eq_id, monomial(eq.r)); init_subset_iterator(eq_id, monomial(eq.r));
@ -633,7 +705,7 @@ namespace euf {
void ac_plugin::init_overlap_iterator(unsigned eq_id, monomial_t const& m) { void ac_plugin::init_overlap_iterator(unsigned eq_id, monomial_t const& m) {
m_eq_occurs.reset(); m_eq_occurs.reset();
for (auto n : m) for (auto n : m)
m_eq_occurs.append(n->root->eqs); m_eq_occurs.append(n->eqs);
compress_eq_occurs(eq_id); compress_eq_occurs(eq_id);
} }
@ -649,17 +721,17 @@ namespace euf {
node* max_n = nullptr; node* max_n = nullptr;
bool has_two = false; bool has_two = false;
for (auto n : m) for (auto n : m)
if (n->root->eqs.size() >= max_use) if (n->eqs.size() >= max_use)
has_two |= max_n && (max_n != n->root), max_n = n->root, max_use = n->root->eqs.size(); has_two |= max_n && (max_n != n), max_n = n, max_use = n->eqs.size();
m_eq_occurs.reset(); m_eq_occurs.reset();
if (has_two) { if (has_two) {
for (auto n : m) for (auto n : m)
if (n->root != max_n) if (n != max_n)
m_eq_occurs.append(n->root->eqs); m_eq_occurs.append(n->eqs);
} }
else { else {
for (auto n : m) { for (auto n : m) {
m_eq_occurs.append(n->root->eqs); m_eq_occurs.append(n->eqs);
break; break;
} }
} }
@ -676,6 +748,8 @@ namespace euf {
continue; continue;
if (id == eq_id) if (id == eq_id)
continue; continue;
if (!is_alive(id))
continue;
m_eq_occurs[j++] = id; m_eq_occurs[j++] = id;
m_eq_seen[id] = true; m_eq_seen[id] = true;
} }
@ -696,8 +770,8 @@ namespace euf {
unsigned min_r = UINT_MAX; unsigned min_r = UINT_MAX;
node* min_n = nullptr; node* min_n = nullptr;
for (auto n : monomial(eq.l)) for (auto n : monomial(eq.l))
if (n->root->eqs.size() < min_r) if (n->eqs.size() < min_r)
min_n = n, min_r = n->root->eqs.size(); min_n = n, min_r = n->eqs.size();
// found node that occurs in fewest eqs // found node that occurs in fewest eqs
VERIFY(min_n); VERIFY(min_n);
return min_n->eqs; return min_n->eqs;
@ -722,7 +796,7 @@ namespace euf {
init_ref_counts(m, check); init_ref_counts(m, check);
return return
all_of(counts, [&](unsigned i) { return check[i] == counts[i]; }) && all_of(counts, [&](unsigned i) { return check[i] == counts[i]; }) &&
all_of(check, [&](unsigned i) { return check[i] == counts[i]; }); all_of(check, [&](unsigned i) { return check[i] == counts[i]; });
} }
void ac_plugin::forward_simplify(unsigned src_eq, unsigned dst_eq) { void ac_plugin::forward_simplify(unsigned src_eq, unsigned dst_eq) {
@ -737,10 +811,10 @@ namespace euf {
auto& src = m_eqs[src_eq]; // src_r_counts, src_l_counts are initialized auto& src = m_eqs[src_eq]; // src_r_counts, src_l_counts are initialized
auto& dst = m_eqs[dst_eq]; auto& dst = m_eqs[dst_eq];
TRACE(plugin, tout << "forward simplify " << eq_pp_ll(*this, src) << " " << eq_pp_ll(*this, dst) << "\n"); TRACE(plugin_verbose, tout << "forward simplify " << eq_pp_ll(*this, src) << " " << eq_pp_ll(*this, dst) << "\n");
if (forward_subsumes(src_eq, dst_eq)) { if (forward_subsumes(src_eq, dst_eq)) {
TRACE(plugin, tout << "forward subsumed\n"); TRACE(plugin_verbose, tout << "forward subsumed\n");
set_status(dst_eq, eq_status::is_dead); set_status(dst_eq, eq_status::is_dead);
return; return;
} }
@ -761,18 +835,14 @@ namespace euf {
unsigned num_overlap = 0; unsigned num_overlap = 0;
for (auto n : monomial(dst.r)) { for (auto n : monomial(dst.r)) {
unsigned id = n->id(); unsigned id = n->id();
unsigned dst_count = m_dst_r_counts[id];
unsigned src_count = m_src_l_counts[id]; unsigned src_count = m_src_l_counts[id];
if (dst_count > src_count) { unsigned dst_count = m_dst_r_counts[id];
m_src_r.push_back(n); if (dst_count < src_count) {
m_dst_r_counts.dec(id, 1); m_dst_r_counts.inc(id, 1);
} ++num_overlap;
else if (dst_count < src_count) {
m_src_r.shrink(src_r_size);
return;
} }
else else
++num_overlap; m_src_r.push_back(n);
} }
// The dst.r has to be a superset of src.l, otherwise simplification does not apply // The dst.r has to be a superset of src.l, otherwise simplification does not apply
if (num_overlap != src_l_size) { if (num_overlap != src_l_size) {
@ -789,7 +859,8 @@ namespace euf {
push_undo(is_update_eq); push_undo(is_update_eq);
m_src_r.reset(); m_src_r.reset();
m_src_r.append(monomial(src.r).m_nodes); m_src_r.append(monomial(src.r).m_nodes);
TRACE(plugin, tout << "rewritten to " << m_pp(*this, monomial(new_r)) << "\n"); TRACE(plugin_verbose, tout << "rewritten to " << m_pp_ll(*this, monomial(new_r)) << "\n");
m_new_eqs.push_back(dst_eq);
} }
bool ac_plugin::backward_simplify(unsigned dst_eq, unsigned src_eq) { bool ac_plugin::backward_simplify(unsigned dst_eq, unsigned src_eq) {
@ -804,22 +875,26 @@ namespace euf {
TRACE(plugin, tout << "backward simplify " << eq_pp_ll(*this, src) << " " << eq_pp_ll(*this, dst) << " can-be-subset: " << can_be_subset(monomial(src.l), monomial(dst.r)) << "\n"); TRACE(plugin, tout << "backward simplify " << eq_pp_ll(*this, src) << " " << eq_pp_ll(*this, dst) << " can-be-subset: " << can_be_subset(monomial(src.l), monomial(dst.r)) << "\n");
if (backward_subsumes(src_eq, dst_eq)) { if (backward_subsumes(src_eq, dst_eq)) {
TRACE(plugin, tout << "backward subsumed\n");
set_status(dst_eq, eq_status::is_dead); set_status(dst_eq, eq_status::is_dead);
return true; return true;
} }
if (!is_equation_oriented(src))
return false;
// check that src.l is a subset of dst.r // check that src.l is a subset of dst.r
if (!can_be_subset(monomial(src.l), monomial(dst.r))) if (!can_be_subset(monomial(src.l), monomial(dst.r)))
return false; return false;
if (!is_subset(m_dst_r_counts, m_src_l_counts, monomial(src.l))) { if (!is_subset(m_dst_r_counts, m_src_l_counts, monomial(src.l)))
TRACE(plugin, tout << "not subset\n"); return false;
return false; if (monomial(dst.r).size() == 0)
} return false;
SASSERT(is_correct_ref_count(monomial(dst.r), m_dst_r_counts)); SASSERT(is_correct_ref_count(monomial(dst.r), m_dst_r_counts));
ptr_vector<node> m(m_dst_r); ptr_vector<node> m(m_dst_r);
init_ref_counts(monomial(src.l), m_src_l_counts); init_ref_counts(monomial(src.l), m_src_l_counts);
//verbose_stream() << "backward simplify " << eq_pp_ll(*this, src_eq) << " for " << eq_pp_ll(*this, dst_eq) << "\n";
rewrite1(m_src_l_counts, monomial(src.r), m_dst_r_counts, m); rewrite1(m_src_l_counts, monomial(src.r), m_dst_r_counts, m);
auto j = justify_rewrite(src_eq, dst_eq); auto j = justify_rewrite(src_eq, dst_eq);
@ -831,30 +906,60 @@ namespace euf {
m_eqs[dst_eq].j = j; m_eqs[dst_eq].j = j;
TRACE(plugin, tout << "rewritten to " << m_pp(*this, monomial(new_r)) << "\n"); TRACE(plugin, tout << "rewritten to " << m_pp(*this, monomial(new_r)) << "\n");
push_undo(is_update_eq); push_undo(is_update_eq);
return true; return true;
} }
void ac_plugin::backward_subsume_new_eqs() {
for (auto f_id : m_new_eqs)
if (is_backward_subsumed(f_id))
set_status(f_id, eq_status::is_dead);
m_new_eqs.reset();
}
bool ac_plugin::is_backward_subsumed(unsigned eq_id) {
return any_of(backward_iterator(eq_id), [&](unsigned other_eq) { return backward_subsumes(other_eq, eq_id); });
}
// dst_eq is fixed, dst_l_count is pre-computed for monomial(dst.l) // dst_eq is fixed, dst_l_count is pre-computed for monomial(dst.l)
// dst_r_counts is pre-computed for monomial(dst.r). // dst_r_counts is pre-computed for monomial(dst.r).
// is dst_eq subsumed by src_eq? // is dst_eq subsumed by src_eq?
bool ac_plugin::backward_subsumes(unsigned src_eq, unsigned dst_eq) { bool ac_plugin::backward_subsumes(unsigned src_eq, unsigned dst_eq) {
auto& src = m_eqs[src_eq]; auto& src = m_eqs[src_eq];
auto& dst = m_eqs[dst_eq]; auto& dst = m_eqs[dst_eq];
TRACE(plugin_verbose, tout << "backward subsumes " << eq_pp_ll(*this, src) << " " << eq_pp_ll(*this, dst) << "\n");
SASSERT(is_correct_ref_count(monomial(dst.l), m_dst_l_counts)); SASSERT(is_correct_ref_count(monomial(dst.l), m_dst_l_counts));
SASSERT(is_correct_ref_count(monomial(dst.r), m_dst_r_counts)); SASSERT(is_correct_ref_count(monomial(dst.r), m_dst_r_counts));
if (!can_be_subset(monomial(src.l), monomial(dst.l))) if (!can_be_subset(monomial(src.l), monomial(dst.l))) {
TRACE(plugin_verbose, tout << "not subset of dst.l\n");
SASSERT(!are_equal(m_eqs[src_eq], m_eqs[dst_eq]));
return false; return false;
if (!can_be_subset(monomial(src.r), monomial(dst.r))) }
if (!can_be_subset(monomial(src.r), monomial(dst.r))) {
TRACE(plugin_verbose, tout << "not subset of dst.r\n");
SASSERT(!are_equal(m_eqs[src_eq], m_eqs[dst_eq]));
return false; return false;
}
unsigned size_diff = monomial(dst.l).size() - monomial(src.l).size(); unsigned size_diff = monomial(dst.l).size() - monomial(src.l).size();
if (size_diff != monomial(dst.r).size() - monomial(src.r).size()) if (size_diff != monomial(dst.r).size() - monomial(src.r).size()) {
TRACE(plugin_verbose, tout << "size diff does not match: " << size_diff << "\n");
SASSERT(!are_equal(m_eqs[src_eq], m_eqs[dst_eq]));
return false; return false;
if (!is_subset(m_dst_l_counts, m_src_l_counts, monomial(src.l))) }
if (!is_subset(m_dst_l_counts, m_src_l_counts, monomial(src.l))) {
TRACE(plugin_verbose, tout << "not subset of dst.l counts\n");
SASSERT(!are_equal(m_eqs[src_eq], m_eqs[dst_eq]));
return false; return false;
if (!is_subset(m_dst_r_counts, m_src_r_counts, monomial(src.r))) }
return false; if (!is_subset(m_dst_r_counts, m_src_r_counts, monomial(src.r))) {
TRACE(plugin_verbose, tout << "not subset of dst.r counts\n");
SASSERT(!are_equal(m_eqs[src_eq], m_eqs[dst_eq]));
return false;
}
SASSERT(is_correct_ref_count(monomial(src.l), m_src_l_counts)); SASSERT(is_correct_ref_count(monomial(src.l), m_src_l_counts));
SASSERT(is_correct_ref_count(monomial(src.r), m_src_r_counts)); SASSERT(is_correct_ref_count(monomial(src.r), m_src_r_counts));
SASSERT(is_correct_ref_count(monomial(dst.l), m_dst_l_counts));
SASSERT(is_correct_ref_count(monomial(dst.r), m_dst_r_counts));
// add difference betwen dst.l and src.l to both src.l, src.r // add difference betwen dst.l and src.l to both src.l, src.r
for (auto n : monomial(dst.l)) { for (auto n : monomial(dst.l)) {
unsigned id = n->id(); unsigned id = n->id();
@ -867,6 +972,13 @@ namespace euf {
} }
// now dst.r and src.r should align and have the same elements. // now dst.r and src.r should align and have the same elements.
// since src.r is a subset of dst.r we iterate over dst.r // since src.r is a subset of dst.r we iterate over dst.r
if (!all_of(monomial(src.r), [&](node* n) {
unsigned id = n->id();
return m_src_r_counts[id] == m_dst_r_counts[id]; })) {
TRACE(plugin_verbose, tout << "dst.r and src.r do not align\n");
SASSERT(!are_equal(m_eqs[src_eq], m_eqs[dst_eq]));
return false;
}
return all_of(monomial(dst.r), [&](node* n) { unsigned id = n->id(); return m_src_r_counts[id] == m_dst_r_counts[id]; }); return all_of(monomial(dst.r), [&](node* n) { unsigned id = n->id(); return m_src_r_counts[id] == m_dst_r_counts[id]; });
} }
@ -937,27 +1049,27 @@ namespace euf {
bool change = false; bool change = false;
unsigned sz = m.size(); unsigned sz = m.size();
unsigned jj = 0; unsigned jj = 0;
//verbose_stream() << "start\n";
do { do {
init_loop: init_loop:
//verbose_stream() << "loop " << jj++ << "\n";
if (m.size() == 1) if (m.size() == 1)
return change; return change;
bloom b; bloom b;
init_ref_counts(m, m_m_counts); init_ref_counts(m, m_m_counts);
unsigned k = 0; unsigned k = 0;
for (auto n : m) { for (auto n : m) {
//verbose_stream() << "inner loop " << k++ << "\n"; if (n->is_zero) {
for (auto eq : n->root->eqs) { m[0] = n;
m.shrink(1);
break;
}
for (auto eq : n->eqs) {
if (!is_processed(eq)) if (!is_processed(eq))
continue; continue;
auto& src = m_eqs[eq]; auto& src = m_eqs[eq];
if (!is_equation_oriented(src)) { if (!is_equation_oriented(src)) {
//verbose_stream() << "equation is not oriented: " << m_eq_ll(*this, src) << "\n";
continue; continue;
if (!orient_equation(src))
continue;
// deduplicate(src.l, src.r);
} }
if (!can_be_subset(monomial(src.l), m, b)) if (!can_be_subset(monomial(src.l), m, b))
continue; continue;
@ -997,18 +1109,19 @@ namespace euf {
void ac_plugin::index_new_r(unsigned eq, monomial_t const& old_r, monomial_t const& new_r) { void ac_plugin::index_new_r(unsigned eq, monomial_t const& old_r, monomial_t const& new_r) {
for (auto n : old_r) for (auto n : old_r)
n->root->n->mark1(); n->n->mark2();
for (auto n : new_r) for (auto n : new_r) {
if (!n->root->n->is_marked1()) { if (!n->n->is_marked2()) {
n->root->eqs.push_back(eq); n->eqs.push_back(eq);
m_node_trail.push_back(n->root); m_node_trail.push_back(n);
n->root->n->mark1(); n->n->mark2();
push_undo(is_add_eq_index); push_undo(is_add_eq_index);
} }
}
for (auto n : old_r) for (auto n : old_r)
n->root->n->unmark1(); n->n->unmark2();
for (auto n : new_r) for (auto n : new_r)
n->root->n->unmark1(); n->n->unmark2();
} }
@ -1018,10 +1131,9 @@ namespace euf {
auto& src = m_eqs[src_eq]; auto& src = m_eqs[src_eq];
auto& dst = m_eqs[dst_eq]; auto& dst = m_eqs[dst_eq];
unsigned max_left = std::max(monomial(src.l).size(), monomial(dst.l).size()); unsigned max_left = std::max(monomial(src.l).size(), monomial(dst.l).size());
unsigned min_right = std::max(monomial(src.r).size(), monomial(dst.r).size()); unsigned min_right = std::max(monomial(src.r).size(), monomial(dst.r).size());
TRACE(plugin, tout << "superpose: "; display_equation_ll(tout, src); tout << " "; display_equation_ll(tout, dst); tout << "\n";); TRACE(plugin, tout << "superpose: "; display_equation_ll(tout, src); tout << " "; display_equation_ll(tout, dst); tout << "\n";);
// AB -> C, AD -> E => BE ~ CD // AB -> C, AD -> E => BE ~ CD
// m_src_ids, m_src_counts contains information about src (call it AD -> E) // m_src_ids, m_src_counts contains information about src (call it AD -> E)
@ -1066,24 +1178,81 @@ namespace euf {
TRACE(plugin, tout << "superpose result: " << m_pp_ll(*this, m_src_r) << "== " << m_pp_ll(*this, m_dst_r) << "\n";); TRACE(plugin, tout << "superpose result: " << m_pp_ll(*this, m_src_r) << "== " << m_pp_ll(*this, m_dst_r) << "\n";);
justification j = justify_rewrite(src_eq, dst_eq); justification j = justify_rewrite(src_eq, dst_eq);
deduplicate(m_src_r, m_dst_r);
reduce(m_dst_r, j); reduce(m_dst_r, j);
reduce(m_src_r, j); reduce(m_src_r, j);
deduplicate(m_src_r, m_dst_r);
TRACE(plugin, tout << "superpose result: " << m_pp_ll(*this, m_src_r) << "== " << m_pp_ll(*this, m_dst_r) << "\n";); TRACE(plugin, tout << "superpose result: " << m_pp_ll(*this, m_src_r) << "== " << m_pp_ll(*this, m_dst_r) << "\n";);
bool added_eq = false; bool added_eq = false;
auto src_r = src.r;
unsigned max_left_new = std::max(m_src_r.size(), m_dst_r.size()); unsigned max_left_new = std::max(m_src_r.size(), m_dst_r.size());
unsigned min_right_new = std::min(m_src_r.size(), m_dst_r.size()); unsigned min_right_new = std::min(m_src_r.size(), m_dst_r.size());
if (max_left_new <= max_left && min_right_new <= min_right) { if (max_left_new <= max_left && min_right_new <= min_right)
init_equation(eq(to_monomial(m_src_r), to_monomial(m_dst_r), j)); added_eq = init_equation(eq(to_monomial(m_src_r), to_monomial(m_dst_r), j));
added_eq = true;
}
m_src_r.reset(); m_src_r.reset();
m_src_r.append(monomial(src.r).m_nodes); m_src_r.append(monomial(src_r).m_nodes);
return added_eq; return added_eq;
} }
bool ac_plugin::is_reducing(eq const& e) const {
auto const& l = monomial(e.l);
auto const& r = monomial(e.r);
return l.size() == 1 && r.size() <= 1;
}
void ac_plugin::forward_reduce(unsigned eq_id) {
auto const& eq = m_eqs[eq_id];
if (!is_reducing(eq))
return;
for (auto other_eq : superpose_iterator(eq_id)) {
SASSERT(is_alive(other_eq));
forward_reduce(eq, other_eq);
}
}
void ac_plugin::forward_reduce(eq const& eq, unsigned other_eq_id) {
auto& other_eq = m_eqs[other_eq_id];
bool change = false;
if (forward_reduce_monomial(eq, monomial(other_eq.l)))
change = true;
if (forward_reduce_monomial(eq, monomial(other_eq.r)))
change = true;
if (change)
set_status(other_eq_id, eq_status::to_simplify);
}
bool ac_plugin::forward_reduce_monomial(eq const& eq, monomial_t& m) {
auto const& r = monomial(eq.r);
unsigned j = 0;
bool change = false;
for (auto n : m) {
unsigned id = n->id();
SASSERT(m_src_l_counts[id] <= 1);
if (m_src_l_counts[id] == 0) {
m.m_nodes[j++] = n;
continue;
}
change = true;
if (r.size() == 0)
// if r is empty, we can remove n from l
continue;
SASSERT(r.size() == 1);
if (r[0]->is_zero) {
m.m_nodes[0] = r[0];
j = 1;
break;
}
m.m_nodes[j++] = r[0];
}
m.m_nodes.shrink(j);
return change;
}
bool ac_plugin::are_equal(monomial_t& a, monomial_t& b) { bool ac_plugin::are_equal(monomial_t& a, monomial_t& b) {
return filter(a) == filter(b) && are_equal(a.m_nodes, b.m_nodes); return filter(a) == filter(b) && are_equal(a.m_nodes, b.m_nodes);
} }
@ -1104,7 +1273,42 @@ namespace euf {
return true; return true;
} }
bool ac_plugin::well_formed(eq const& e) const {
if (e.l == e.r)
return false; // trivial equation
for (auto n : monomial(e.l)) {
if (n->is_zero && monomial(e.l).size() > 1)
return false; // zero is not allowed in equations
}
for (auto n : monomial(e.r)) {
if (n->is_zero && monomial(e.r).size() > 1)
return false; // zero is not allowed in equations
}
return true;
}
void ac_plugin::deduplicate(ptr_vector<node>& a, ptr_vector<node>& b) { void ac_plugin::deduplicate(ptr_vector<node>& a, ptr_vector<node>& b) {
{
unsigned j = 0;
for (auto n : a) {
if (n->is_zero) {
//verbose_stream() << "deduplicate: removing zero from a: " << m_pp(*this, a) << "\n";
a[0] = n;
a.shrink(1);
break;
}
}
j = 0;
for (auto n : b) {
if (n->is_zero) {
// verbose_stream() << "deduplicate: removing zero from b: " << m_pp(*this, b) << "\n";
b[0] = n;
b.shrink(1);
break;
}
}
}
if (!m_is_injective) if (!m_is_injective)
return; return;
m_eq_counts.reset(); m_eq_counts.reset();
@ -1144,7 +1348,7 @@ namespace euf {
// //
void ac_plugin::propagate_shared() { void ac_plugin::propagate_shared() {
TRACE(plugin, tout << "num shared todo " << m_shared_todo.size() << "\n"); TRACE(plugin_verbose, tout << "num shared todo " << m_shared_todo.size() << "\n");
if (m_shared_todo.empty()) if (m_shared_todo.empty())
return; return;
while (!m_shared_todo.empty()) { while (!m_shared_todo.empty()) {
@ -1156,11 +1360,11 @@ namespace euf {
m_monomial_table.reset(); m_monomial_table.reset();
for (auto const& s1 : m_shared) { for (auto const& s1 : m_shared) {
shared s2; shared s2;
TRACE(plugin, tout << "shared " << s1.m << ": " << m_pp_ll(*this, monomial(s1.m)) << "\n"); TRACE(plugin_verbose, tout << "shared " << s1.m << ": " << m_pp_ll(*this, monomial(s1.m)) << "\n");
if (!m_monomial_table.find(s1.m, s2)) if (!m_monomial_table.find(s1.m, s2))
m_monomial_table.insert(s1.m, s1); m_monomial_table.insert(s1.m, s1);
else if (s2.n->get_root() != s1.n->get_root()) { else if (s2.n->get_root() != s1.n->get_root()) {
TRACE(plugin, tout << m_pp(*this, monomial(s1.m)) << " == " << m_pp(*this, monomial(s2.m)) << "\n"); TRACE(plugin, tout << "merge shared " << g.bpp(s1.n->get_root()) << " and " << g.bpp(s2.n->get_root()) << "\n");
push_merge(s1.n, s2.n, justification::dependent(m_dep_manager.mk_join(m_dep_manager.mk_leaf(s1.j), m_dep_manager.mk_leaf(s2.j)))); push_merge(s1.n, s2.n, justification::dependent(m_dep_manager.mk_join(m_dep_manager.mk_leaf(s1.j), m_dep_manager.mk_leaf(s2.j))));
} }
} }
@ -1171,31 +1375,32 @@ namespace euf {
auto old_m = s.m; auto old_m = s.m;
auto old_n = monomial(old_m).m_src; auto old_n = monomial(old_m).m_src;
ptr_vector<node> m1(monomial(old_m).m_nodes); ptr_vector<node> m1(monomial(old_m).m_nodes);
TRACE(plugin, tout << "simplify shared: " << g.bpp(old_n) << ": " << m_pp_ll(*this, monomial(old_m)) << "\n"); TRACE(plugin_verbose, tout << "simplify shared: " << g.bpp(old_n) << ": " << m_pp_ll(*this, monomial(old_m)) << "\n");
if (!reduce(m1, j)) if (!reduce(m1, j))
return; return;
auto new_n = from_monomial(m1); enode* new_n = nullptr;
new_n = m1.empty() ? get_unit(s.n) : from_monomial(m1);
auto new_m = to_monomial(new_n, m1); auto new_m = to_monomial(new_n, m1);
// update shared occurrences for members of the new monomial that are not already in the old monomial. // update shared occurrences for members of the new monomial that are not already in the old monomial.
for (auto n : monomial(old_m)) for (auto n : monomial(old_m))
n->root->n->mark1(); n->n->mark2();
for (auto n : m1) { for (auto n : m1) {
if (!n->root->n->is_marked1()) { if (!n->n->is_marked2()) {
n->root->shared.push_back(idx); n->shared.push_back(idx);
m_shared_todo.insert(idx); m_shared_todo.insert(idx);
m_node_trail.push_back(n->root); m_node_trail.push_back(n);
push_undo(is_add_shared_index); push_undo(is_add_shared_index);
} }
} }
for (auto n : monomial(old_m)) for (auto n : monomial(old_m))
n->root->n->unmark1(); n->n->unmark2();
m_update_shared_trail.push_back({ idx, s }); m_update_shared_trail.push_back({ idx, s });
push_undo(is_update_shared); push_undo(is_update_shared);
m_shared[idx].m = new_m; m_shared[idx].m = new_m;
m_shared[idx].j = j; m_shared[idx].j = j;
TRACE(plugin, tout << "shared simplified to " << m_pp_ll(*this, monomial(new_m)) << "\n"); TRACE(plugin_verbose, tout << "shared simplified to " << m_pp_ll(*this, monomial(new_m)) << "\n");
push_merge(old_n, new_n, j); push_merge(old_n, new_n, j);
} }
@ -1215,8 +1420,8 @@ namespace euf {
justification::dependency* ac_plugin::justify_monomial(justification::dependency* j, monomial_t const& m) { justification::dependency* ac_plugin::justify_monomial(justification::dependency* j, monomial_t const& m) {
for (auto n : m) for (auto n : m)
if (n->root->n != n->n) if (n->n != n->n)
j = m_dep_manager.mk_join(j, m_dep_manager.mk_leaf(justification::equality(n->root->n, n->n))); j = m_dep_manager.mk_join(j, m_dep_manager.mk_leaf(justification::equality(n->n, n->n)));
return j; return j;
} }

View file

@ -36,37 +36,19 @@ namespace euf {
class ac_plugin : public plugin { class ac_plugin : public plugin {
// enode structure for AC equivalences struct stats {
struct node { unsigned m_num_superpositions = 0;// number of superpositions
enode* n; // associated enode
node* root; // path compressed root
node* next; // next in equivalence class
justification j; // justification for equality
node* target = nullptr; // justified next
unsigned_vector shared; // shared occurrences
unsigned_vector eqs; // equality occurrences
unsigned id() const { return root->n->get_id(); }
static node* mk(region& r, enode* n);
}; };
class equiv { // enode structure for AC equivalences
node& n; struct node {
public: enode* n; // associated enode
class iterator { unsigned_vector shared; // shared occurrences
node* m_first; unsigned_vector eqs; // equality occurrences
node* m_last; bool is_zero = false;
public:
iterator(node* n, node* m) : m_first(n), m_last(m) {} unsigned id() const { return n->get_id(); }
node* operator*() { return m_first; } static node* mk(region& r, enode* n);
iterator& operator++() { if (!m_last) m_last = m_first; m_first = m_first->next; return *this; }
iterator operator++(int) { iterator tmp = *this; ++*this; return tmp; }
bool operator!=(iterator const& other) const { return m_last != other.m_last || m_first != other.m_first; }
};
equiv(node& _n) :n(_n) {}
equiv(node* _n) :n(*_n) {}
iterator begin() const { return iterator(&n, nullptr); }
iterator end() const { return iterator(&n, &n); }
}; };
struct bloom { struct bloom {
@ -75,7 +57,7 @@ namespace euf {
}; };
enum eq_status { enum eq_status {
processed, to_simplify, is_dead processed, to_simplify, is_reducing_eq, is_dead
}; };
// represent equalities added by merge_eh and by superposition // represent equalities added by merge_eh and by superposition
@ -150,6 +132,10 @@ namespace euf {
tracked_uint_set m_shared_todo; tracked_uint_set m_shared_todo;
uint64_t m_tick = 1; uint64_t m_tick = 1;
symbol m_name; symbol m_name;
unsigned m_fuel = 0;
unsigned m_fuel_inc = 3;
stats m_stats;
mutable symbol m_superposition_stats, m_eqs_stats;
@ -163,7 +149,6 @@ namespace euf {
is_add_eq, is_add_eq,
is_add_monomial, is_add_monomial,
is_add_node, is_add_node,
is_merge_node,
is_update_eq, is_update_eq,
is_add_shared_index, is_add_shared_index,
is_add_eq_index, is_add_eq_index,
@ -200,14 +185,35 @@ namespace euf {
bool can_be_subset(monomial_t& subset, ptr_vector<node> const& m, bloom& b); bool can_be_subset(monomial_t& subset, ptr_vector<node> const& m, bloom& b);
bool are_equal(ptr_vector<node> const& a, ptr_vector<node> const& b); bool are_equal(ptr_vector<node> const& a, ptr_vector<node> const& b);
bool are_equal(monomial_t& a, monomial_t& b); bool are_equal(monomial_t& a, monomial_t& b);
bool are_equal(eq const& a, eq const& b) {
return are_equal(monomial(a.l), monomial(b.l)) && are_equal(monomial(a.r), monomial(b.r));
}
bool well_formed(eq const& e) const;
bool is_reducing(eq const& e) const;
void forward_reduce(unsigned eq_id);
void forward_reduce(eq const& src, unsigned dst);
bool forward_reduce_monomial(eq const& eq, monomial_t& m);
void backward_subsume_new_eqs();
bool is_backward_subsumed(unsigned dst_eq);
bool backward_subsumes(unsigned src_eq, unsigned dst_eq); bool backward_subsumes(unsigned src_eq, unsigned dst_eq);
bool forward_subsumes(unsigned src_eq, unsigned dst_eq); bool forward_subsumes(unsigned src_eq, unsigned dst_eq);
void init_equation(eq const& e); enode_vector m_units;
enode* get_unit(enode* n) const {
for (auto u : m_units) {
if (u->get_sort() == n->get_sort())
return u;
}
UNREACHABLE();
return nullptr;
}
bool init_equation(eq const& e);
bool orient_equation(eq& e); bool orient_equation(eq& e);
void set_status(unsigned eq_id, eq_status s); void set_status(unsigned eq_id, eq_status s);
unsigned pick_next_eq(); unsigned pick_next_eq();
unsigned_vector m_new_eqs;
void forward_simplify(unsigned eq_id, unsigned using_eq); void forward_simplify(unsigned eq_id, unsigned using_eq);
bool backward_simplify(unsigned eq_id, unsigned using_eq); bool backward_simplify(unsigned eq_id, unsigned using_eq);
bool superpose(unsigned src_eq, unsigned dst_eq); bool superpose(unsigned src_eq, unsigned dst_eq);
@ -249,6 +255,7 @@ namespace euf {
bool is_to_simplify(unsigned eq) const { return m_eqs[eq].status == eq_status::to_simplify; } bool is_to_simplify(unsigned eq) const { return m_eqs[eq].status == eq_status::to_simplify; }
bool is_processed(unsigned eq) const { return m_eqs[eq].status == eq_status::processed; } bool is_processed(unsigned eq) const { return m_eqs[eq].status == eq_status::processed; }
bool is_reducing(unsigned eq) const { return m_eqs[eq].status == eq_status::is_reducing_eq; }
bool is_alive(unsigned eq) const { return m_eqs[eq].status != eq_status::is_dead; } bool is_alive(unsigned eq) const { return m_eqs[eq].status != eq_status::is_dead; }
justification justify_rewrite(unsigned eq1, unsigned eq2); justification justify_rewrite(unsigned eq1, unsigned eq2);
@ -279,6 +286,10 @@ namespace euf {
ac_plugin(egraph& g, func_decl* f); ac_plugin(egraph& g, func_decl* f);
void set_injective() { m_is_injective = true; } void set_injective() { m_is_injective = true; }
void add_unit(enode*);
void add_zero(enode*);
theory_id get_id() const override { return m_fid; } theory_id get_id() const override { return m_fid; }
@ -294,6 +305,8 @@ namespace euf {
std::ostream& display(std::ostream& out) const override; std::ostream& display(std::ostream& out) const override;
void collect_statistics(statistics& st) const override;
void set_undo(std::function<void(void)> u) { m_undo_notify = u; } void set_undo(std::function<void(void)> u) { m_undo_notify = u; }
struct eq_pp { struct eq_pp {

View file

@ -31,6 +31,25 @@ namespace euf {
std::function<void(void)> umul = [&]() { m_undo.push_back(undo_t::undo_mul); }; std::function<void(void)> umul = [&]() { m_undo.push_back(undo_t::undo_mul); };
m_mul.set_undo(umul); m_mul.set_undo(umul);
m_add.set_injective(); m_add.set_injective();
auto e = a.mk_int(0);
auto n = g.find(e) ? g.find(e) : g.mk(e, 0, 0, nullptr);
m_add.add_unit(n);
m_mul.add_zero(n);
e = a.mk_real(0);
n = g.find(e) ? g.find(e) : g.mk(e, 0, 0, nullptr);
m_add.add_unit(n);
m_mul.add_zero(n);
e = a.mk_int(1);
n = g.find(e) ? g.find(e) : g.mk(e, 0, 0, nullptr);
m_mul.add_unit(n);
e = a.mk_real(1);
n = g.find(e) ? g.find(e) : g.mk(e, 0, 0, nullptr);
m_mul.add_unit(n);
} }
void arith_plugin::register_node(enode* n) { void arith_plugin::register_node(enode* n) {

View file

@ -46,6 +46,11 @@ namespace euf {
void propagate() override; void propagate() override;
std::ostream& display(std::ostream& out) const override; std::ostream& display(std::ostream& out) const override;
void collect_statistics(statistics& st) const override {
m_add.collect_statistics(st);
m_mul.collect_statistics(st);
}
}; };
} }

View file

@ -117,6 +117,7 @@ namespace euf {
enode* egraph::mk(expr* f, unsigned generation, unsigned num_args, enode *const* args) { enode* egraph::mk(expr* f, unsigned generation, unsigned num_args, enode *const* args) {
SASSERT(!find(f)); SASSERT(!find(f));
TRACE(euf, tout << "mk: " << mk_bounded_pp(f, m) << " generation: " << generation << " num_args: " << num_args << "\n";);
force_push(); force_push();
enode *n = mk_enode(f, generation, num_args, args); enode *n = mk_enode(f, generation, num_args, args);
@ -157,6 +158,21 @@ namespace euf {
} }
void egraph::propagate_plugins() { void egraph::propagate_plugins() {
if (m_plugins.empty())
return;
if (m_plugin_qhead < m_new_th_eqs.size())
m_updates.push_back(update_record(m_plugin_qhead, update_record::plugin_qhead()));
for (; m_plugin_qhead < m_new_th_eqs.size(); ++m_plugin_qhead) {
auto const& eq = m_new_th_eqs[m_plugin_qhead];
auto* p = get_plugin(eq.id());
if (!p)
continue;
if (eq.is_eq())
p->merge_eh(eq.child(), eq.root());
else
p->diseq_eh(eq.eq());
}
for (auto* p : m_plugins) for (auto* p : m_plugins)
if (p) if (p)
p->propagate(); p->propagate();
@ -167,23 +183,18 @@ namespace euf {
m_new_th_eqs.push_back(th_eq(id, v1, v2, c, r)); m_new_th_eqs.push_back(th_eq(id, v1, v2, c, r));
m_updates.push_back(update_record(update_record::new_th_eq())); m_updates.push_back(update_record(update_record::new_th_eq()));
++m_stats.m_num_th_eqs; ++m_stats.m_num_th_eqs;
auto* p = get_plugin(id);
if (p)
p->merge_eh(c, r);
} }
void egraph::add_th_diseq(theory_id id, theory_var v1, theory_var v2, enode* eq) { void egraph::add_th_diseq(theory_id id, theory_var v1, theory_var v2, enode* eq) {
if (!th_propagates_diseqs(id)) if (!th_propagates_diseqs(id))
return; return;
TRACE(euf_verbose, tout << "eq: " << v1 << " != " << v2 << "\n";); TRACE(euf_verbose, tout << "eq: " << v1 << " != " << v2 << "\n";);
m_new_th_eqs.push_back(th_eq(id, v1, v2, eq->get_expr())); m_new_th_eqs.push_back(th_eq(id, v1, v2, eq));
m_updates.push_back(update_record(update_record::new_th_eq())); m_updates.push_back(update_record(update_record::new_th_eq()));
auto* p = get_plugin(id);
if (p)
p->diseq_eh(eq);
++m_stats.m_num_th_diseqs; ++m_stats.m_num_th_diseqs;
} }
void egraph::add_literal(enode* n, enode* ante) { void egraph::add_literal(enode* n, enode* ante) {
TRACE(euf, tout << "propagate " << bpp(n) << " " << bpp(ante) << "\n"); TRACE(euf, tout << "propagate " << bpp(n) << " " << bpp(ante) << "\n");
if (!m_on_propagate_literal) if (!m_on_propagate_literal)
@ -447,6 +458,9 @@ namespace euf {
case update_record::tag_t::is_new_th_eq_qhead: case update_record::tag_t::is_new_th_eq_qhead:
m_new_th_eqs_qhead = p.qhead; m_new_th_eqs_qhead = p.qhead;
break; break;
case update_record::tag_t::is_plugin_qhead:
m_plugin_qhead = p.qhead;
break;
case update_record::tag_t::is_inconsistent: case update_record::tag_t::is_inconsistent:
m_inconsistent = p.m_inconsistent; m_inconsistent = p.m_inconsistent;
break; break;
@ -546,16 +560,18 @@ namespace euf {
void egraph::remove_parents(enode* r) { void egraph::remove_parents(enode* r) {
TRACE(euf_verbose, tout << bpp(r) << "\n"); TRACE(euf_verbose, tout << bpp(r) << "\n");
SASSERT(all_of(enode_parents(r), [&](enode* p) { return !p->is_marked1(); })); SASSERT(all_of(enode_parents(r), [&](enode* p) { return !p->is_marked1(); }));
TRACE(euf, tout << "remove_parents " << bpp(r) << "\n");
for (enode* p : enode_parents(r)) { for (enode* p : enode_parents(r)) {
if (p->is_marked1()) if (p->is_marked1())
continue; continue;
if (p->cgc_enabled()) { if (p->cgc_enabled()) {
if (!p->is_cgr()) if (!p->is_cgr())
continue; continue;
TRACE(euf, tout << "removing " << m_table.contains_ptr(p) << " " << bpp(p) << "\n");
SASSERT(m_table.contains_ptr(p)); SASSERT(m_table.contains_ptr(p));
p->mark1(); p->mark1();
erase_from_table(p); erase_from_table(p);
CTRACE(euf_verbose, m_table.contains_ptr(p), tout << bpp(p) << "\n"; display(tout)); CTRACE(euf, m_table.contains_ptr(p), tout << bpp(p) << "\n"; display(tout));
SASSERT(!m_table.contains_ptr(p)); SASSERT(!m_table.contains_ptr(p));
} }
else if (p->is_equality()) else if (p->is_equality())
@ -564,15 +580,16 @@ namespace euf {
} }
void egraph::reinsert_parents(enode* r1, enode* r2) { void egraph::reinsert_parents(enode* r1, enode* r2) {
TRACE(euf, tout << "reinsert_parents " << bpp(r1) << " " << bpp(r2) << "\n";);
for (enode* p : enode_parents(r1)) { for (enode* p : enode_parents(r1)) {
if (!p->is_marked1()) if (!p->is_marked1())
continue; continue;
p->unmark1(); p->unmark1();
TRACE(euf_verbose, tout << "reinsert " << bpp(r1) << " " << bpp(r2) << " " << bpp(p) << " " << p->cgc_enabled() << "\n";); TRACE(euf, tout << "reinsert " << bpp(r1) << " " << bpp(r2) << " " << bpp(p) << " " << p->cgc_enabled() << "\n";);
if (p->cgc_enabled()) { if (p->cgc_enabled()) {
auto [p_other, comm] = insert_table(p); auto [p_other, comm] = insert_table(p);
SASSERT(m_table.contains_ptr(p) == (p_other == p)); SASSERT(m_table.contains_ptr(p) == (p_other == p));
CTRACE(euf_verbose, p_other != p, tout << "reinsert " << bpp(p) << " == " << bpp(p_other) << " " << p->value() << " " << p_other->value() << "\n"); CTRACE(euf, p_other != p, tout << "reinsert " << bpp(p) << " == " << bpp(p_other) << " " << p->value() << " " << p_other->value() << "\n");
if (p_other != p) if (p_other != p)
m_to_merge.push_back(to_merge(p_other, p, comm)); m_to_merge.push_back(to_merge(p_other, p, comm));
else else
@ -957,6 +974,9 @@ namespace euf {
st.update("euf propagations theory eqs", m_stats.m_num_th_eqs); st.update("euf propagations theory eqs", m_stats.m_num_th_eqs);
st.update("euf propagations theory diseqs", m_stats.m_num_th_diseqs); st.update("euf propagations theory diseqs", m_stats.m_num_th_diseqs);
st.update("euf propagations literal", m_stats.m_num_lits); st.update("euf propagations literal", m_stats.m_num_lits);
for (auto p : m_plugins)
if (p)
p->collect_statistics(st);
} }
void egraph::copy_from(egraph const& src, std::function<void*(void*)>& copy_justification) { void egraph::copy_from(egraph const& src, std::function<void*(void*)>& copy_justification) {

View file

@ -58,7 +58,7 @@ namespace euf {
theory_var m_v2; theory_var m_v2;
union { union {
enode* m_child; enode* m_child;
expr* m_eq; enode* m_eq;
}; };
enode* m_root; enode* m_root;
public: public:
@ -68,10 +68,10 @@ namespace euf {
theory_var v2() const { return m_v2; } theory_var v2() const { return m_v2; }
enode* child() const { SASSERT(is_eq()); return m_child; } enode* child() const { SASSERT(is_eq()); return m_child; }
enode* root() const { SASSERT(is_eq()); return m_root; } enode* root() const { SASSERT(is_eq()); return m_root; }
expr* eq() const { SASSERT(!is_eq()); return m_eq; } enode* eq() const { SASSERT(!is_eq()); return m_eq; }
th_eq(theory_id id, theory_var v1, theory_var v2, enode* c, enode* r) : th_eq(theory_id id, theory_var v1, theory_var v2, enode* c, enode* r) :
m_id(id), m_v1(v1), m_v2(v2), m_child(c), m_root(r) {} m_id(id), m_v1(v1), m_v2(v2), m_child(c), m_root(r) {}
th_eq(theory_id id, theory_var v1, theory_var v2, expr* eq) : th_eq(theory_id id, theory_var v1, theory_var v2, enode* eq) :
m_id(id), m_v1(v1), m_v2(v2), m_eq(eq), m_root(nullptr) {} m_id(id), m_v1(v1), m_v2(v2), m_eq(eq), m_root(nullptr) {}
}; };
@ -116,6 +116,7 @@ namespace euf {
struct replace_th_var {}; struct replace_th_var {};
struct new_th_eq {}; struct new_th_eq {};
struct new_th_eq_qhead {}; struct new_th_eq_qhead {};
struct plugin_qhead {};
struct inconsistent {}; struct inconsistent {};
struct value_assignment {}; struct value_assignment {};
struct lbl_hash {}; struct lbl_hash {};
@ -125,7 +126,7 @@ namespace euf {
struct plugin_undo {}; struct plugin_undo {};
enum class tag_t { is_set_parent, is_add_node, is_toggle_cgc, is_toggle_merge_tf, is_update_children, enum class tag_t { is_set_parent, is_add_node, is_toggle_cgc, is_toggle_merge_tf, is_update_children,
is_add_th_var, is_replace_th_var, is_new_th_eq, is_add_th_var, is_replace_th_var, is_new_th_eq,
is_lbl_hash, is_new_th_eq_qhead, is_lbl_hash, is_new_th_eq_qhead, is_plugin_qhead,
is_inconsistent, is_value_assignment, is_lbl_set, is_set_relevant, is_inconsistent, is_value_assignment, is_lbl_set, is_set_relevant,
is_plugin_undo }; is_plugin_undo };
tag_t tag; tag_t tag;
@ -158,6 +159,8 @@ namespace euf {
tag(tag_t::is_new_th_eq), r1(nullptr), n1(nullptr), r2_num_parents(0) {} tag(tag_t::is_new_th_eq), r1(nullptr), n1(nullptr), r2_num_parents(0) {}
update_record(unsigned qh, new_th_eq_qhead): update_record(unsigned qh, new_th_eq_qhead):
tag(tag_t::is_new_th_eq_qhead), r1(nullptr), n1(nullptr), qhead(qh) {} tag(tag_t::is_new_th_eq_qhead), r1(nullptr), n1(nullptr), qhead(qh) {}
update_record(unsigned qh, plugin_qhead) :
tag(tag_t::is_plugin_qhead), r1(nullptr), n1(nullptr), qhead(qh) {}
update_record(bool inc, inconsistent) : update_record(bool inc, inconsistent) :
tag(tag_t::is_inconsistent), r1(nullptr), n1(nullptr), m_inconsistent(inc) {} tag(tag_t::is_inconsistent), r1(nullptr), n1(nullptr), m_inconsistent(inc) {}
update_record(enode* n, value_assignment) : update_record(enode* n, value_assignment) :
@ -196,6 +199,7 @@ namespace euf {
enode *m_n2 = nullptr; enode *m_n2 = nullptr;
justification m_justification; justification m_justification;
unsigned m_new_th_eqs_qhead = 0; unsigned m_new_th_eqs_qhead = 0;
unsigned m_plugin_qhead = 0;
svector<th_eq> m_new_th_eqs; svector<th_eq> m_new_th_eqs;
bool_vector m_th_propagates_diseqs; bool_vector m_th_propagates_diseqs;
enode_vector m_todo; enode_vector m_todo;

View file

@ -35,7 +35,7 @@ namespace euf {
void plugin::push_merge(enode* a, enode* b) { void plugin::push_merge(enode* a, enode* b) {
if (a->get_root() == b->get_root()) if (a->get_root() == b->get_root())
return; // already merged return; // already merged
TRACE(plugin, tout << g.bpp(a) << " == " << g.bpp(b) << "\n"); TRACE(plugin, tout << "push-merge " << g.bpp(a) << " == " << g.bpp(b) << "\n");
g.push_merge(a, b, justification::axiom(get_id())); g.push_merge(a, b, justification::axiom(get_id()));
} }

View file

@ -19,6 +19,7 @@ Author:
#pragma once #pragma once
#include "util/statistics.h"
#include "ast/euf/euf_enode.h" #include "ast/euf/euf_enode.h"
#include "ast/euf/euf_justification.h" #include "ast/euf/euf_justification.h"
@ -53,6 +54,8 @@ namespace euf {
virtual void undo() = 0; virtual void undo() = 0;
virtual std::ostream& display(std::ostream& out) const = 0; virtual std::ostream& display(std::ostream& out) const = 0;
virtual void collect_statistics(statistics& st) const {}
}; };
} }

View file

@ -68,6 +68,8 @@ namespace euf {
m_mam(mam::mk(*this, *this)), m_mam(mam::mk(*this, *this)),
m_canonical(m), m_canonical(m),
m_eargs(m), m_eargs(m),
m_expr_trail(m),
m_consequences(m),
m_canonical_proofs(m), m_canonical_proofs(m),
// m_infer_patterns(m, m_smt_params), // m_infer_patterns(m, m_smt_params),
m_deps(m), m_deps(m),
@ -135,6 +137,7 @@ namespace euf {
}; };
m_matcher.set_on_match(on_match); m_matcher.set_on_match(on_match);
} }
completion::~completion() { completion::~completion() {
@ -230,15 +233,51 @@ namespace euf {
read_egraph(); read_egraph();
IF_VERBOSE(1, verbose_stream() << "(euf.completion :rounds " << rounds << " :instances " << m_stats.m_num_instances << " :stop " << should_stop() << ")\n"); IF_VERBOSE(1, verbose_stream() << "(euf.completion :rounds " << rounds << " :instances " << m_stats.m_num_instances << " :stop " << should_stop() << ")\n");
} }
map_congruences();
for (auto c : m_consequences)
add_consequence(c);
TRACE(euf_completion, m_egraph.display(tout));
}
void completion::map_congruences() {
unsigned sz = qtail();
for (unsigned i = qhead(); i < sz; ++i) {
auto [f, p, d] = m_fmls[i]();
if (is_app(f) && to_app(f)->get_num_args() == 1 && symbol("congruences") == to_app(f)->get_decl()->get_name())
map_congruence(to_app(f)->get_arg(0));
}
}
void completion::map_congruence(expr* t) {
auto n = m_egraph.find(t);
if (!n)
return;
ptr_vector<expr> args;
for (auto s : enode_class(n)) {
expr_ref r(s->get_expr(), m);
m_rewriter(r);
args.push_back(r);
}
expr_ref cong(m);
cong = m.mk_app(symbol("congruence"), args.size(), args.data(), m.mk_bool_sort());
m_fmls.add(dependent_expr(m, cong, nullptr, nullptr));
}
void completion::add_consequence(expr* f) {
expr_ref r(f, m);
m_rewriter(r);
f = r.get();
// verbose_stream() << r << "\n";
auto cons = m.mk_app(symbol("consequence"), 1, &f, m.mk_bool_sort());
m_fmls.add(dependent_expr(m, cons, nullptr, nullptr));
} }
void completion::add_egraph() { void completion::add_egraph() {
m_nodes_to_canonize.reset(); m_nodes_to_canonize.reset();
unsigned sz = qtail(); unsigned sz = qtail();
for (unsigned i = qhead(); i < sz; ++i) { for (unsigned i = qhead(); i < sz; ++i) {
auto [f, p, d] = m_fmls[i](); auto [f, p, d] = m_fmls[i]();
add_constraint(f, p, d); add_constraint(f, p, d);
} }
m_should_propagate = true; m_should_propagate = true;
@ -248,6 +287,7 @@ namespace euf {
m_mam->propagate(); m_mam->propagate();
flush_binding_queue(); flush_binding_queue();
propagate_rules(); propagate_rules();
propagate_closures();
IF_VERBOSE(11, verbose_stream() << "propagate " << m_stats.m_num_instances << "\n"); IF_VERBOSE(11, verbose_stream() << "propagate " << m_stats.m_num_instances << "\n");
if (!m_should_propagate && !should_stop()) if (!m_should_propagate && !should_stop())
propagate_all_rules(); propagate_all_rules();
@ -271,7 +311,7 @@ namespace euf {
for (auto* ch : enode_args(n)) for (auto* ch : enode_args(n))
m_nodes_to_canonize.push_back(ch); m_nodes_to_canonize.push_back(ch);
}; };
expr* x, * y; expr* x = nullptr, * y = nullptr;
if (m.is_eq(f, x, y)) { if (m.is_eq(f, x, y)) {
expr_ref x1(x, m); expr_ref x1(x, m);
expr_ref y1(y, m); expr_ref y1(y, m);
@ -285,16 +325,20 @@ namespace euf {
if (a->get_root() == b->get_root()) if (a->get_root() == b->get_root())
return; return;
m_egraph.merge(a, b, to_ptr(push_pr_dep(pr, d))); m_egraph.merge(a, b, to_ptr(push_pr_dep(pr, d)));
m_egraph.propagate();
add_children(a); add_children(a);
add_children(b); add_children(b);
auto a1 = mk_enode(x); auto a1 = mk_enode(x);
if (a1->get_root() != a->get_root()) { if (a1->get_root() != a->get_root()) {
m_egraph.merge(a, a1, nullptr); m_egraph.merge(a, a1, nullptr);
m_egraph.propagate();
add_children(a1); add_children(a1);
} }
auto b1 = mk_enode(y); auto b1 = mk_enode(y);
if (b1->get_root() != b->get_root()) { if (b1->get_root() != b->get_root()) {
TRACE(euf, tout << "merge and propagate\n");
m_egraph.merge(b, b1, nullptr); m_egraph.merge(b, b1, nullptr);
m_egraph.propagate();
add_children(b1); add_children(b1);
} }
@ -310,6 +354,7 @@ namespace euf {
add_quantifiers(f); add_quantifiers(f);
auto j = to_ptr(push_pr_dep(pr, d)); auto j = to_ptr(push_pr_dep(pr, d));
m_egraph.new_diseq(n, j); m_egraph.new_diseq(n, j);
m_egraph.propagate();
add_children(n); add_children(n);
m_should_propagate = true; m_should_propagate = true;
if (m_side_condition_solver) if (m_side_condition_solver)
@ -322,6 +367,7 @@ namespace euf {
return; return;
IF_VERBOSE(1, verbose_stream() << "fml: " << mk_pp(f, m) << "\n"); IF_VERBOSE(1, verbose_stream() << "fml: " << mk_pp(f, m) << "\n");
m_egraph.merge(n, m_tt, to_ptr(push_pr_dep(pr, d))); m_egraph.merge(n, m_tt, to_ptr(push_pr_dep(pr, d)));
m_egraph.propagate();
add_children(n); add_children(n);
if (is_forall(f)) { if (is_forall(f)) {
quantifier* q = to_quantifier(f); quantifier* q = to_quantifier(f);
@ -352,7 +398,7 @@ namespace euf {
} }
add_rule(f, pr, d); add_rule(f, pr, d);
if (!is_forall(f) && !m.is_implies(f)) { if (!is_forall(f) && !m.is_implies(f) && !m.is_or(f)) {
add_quantifiers(f); add_quantifiers(f);
if (m_side_condition_solver) if (m_side_condition_solver)
m_side_condition_solver->add_constraint(f, pr, d); m_side_condition_solver->add_constraint(f, pr, d);
@ -388,18 +434,27 @@ namespace euf {
else if (is_quantifier(t)) { else if (is_quantifier(t)) {
auto q = to_quantifier(t); auto q = to_quantifier(t);
auto nd = q->get_num_decls(); auto nd = q->get_num_decls();
verbose_stream() << "bind " << mk_pp(q, m) << "\n"; IF_VERBOSE(1, verbose_stream() << "bind " << mk_pp(q, m) << "\n");
for (unsigned i = 0; i < nd; ++i) { for (unsigned i = 0; i < nd; ++i) {
auto name = std::string("bound!") + std::to_string(bound.size()); auto name = std::string("bound!") + std::to_string(bound.size());
auto b = m.mk_const(name, q->get_decl_sort(i)); auto b = m.mk_const(name, q->get_decl_sort(i));
// TODO: persist bound variables withn scope to avoid reference count crashes if (b->get_ref_count() == 0) {
m_expr_trail.push_back(b);
get_trail().push(push_back_vector(m_expr_trail));
}
bound.push_back(b); bound.push_back(b);
} }
expr_ref inst = var_subst(m)(q->get_expr(), bound); expr_ref inst = var_subst(m)(q->get_expr(), bound);
if (!m_egraph.find(inst)) { if (!m_egraph.find(inst)) {
expr_ref clos(m);
m_closures.insert(q, { bound, inst }); m_closures.insert(q, { bound, inst });
get_trail().push(insert_map(m_closures, q)); get_trail().push(insert_map(m_closures, q));
mk_enode(inst); // ensure that inst occurs in a foreign context to enable equality propagation
// on inst.
func_decl* f = m.mk_func_decl(symbol("clos!"), inst->get_sort(), m.mk_bool_sort());
clos = m.mk_app(f, inst);
mk_enode(clos);
// TODO: handle nested quantifiers after m_closures is updated to // TODO: handle nested quantifiers after m_closures is updated to
// index on sort declaration prefix together with quantifier // index on sort declaration prefix together with quantifier
// add_quantifiers(bound, inst); // add_quantifiers(bound, inst);
@ -445,13 +500,31 @@ namespace euf {
void completion::add_rule(expr* f, proof* pr, expr_dependency* d) { void completion::add_rule(expr* f, proof* pr, expr_dependency* d) {
expr* x = nullptr, * y = nullptr; expr* x = nullptr, * y = nullptr;
if (!m.is_implies(f, x, y))
return;
expr_ref_vector body(m); expr_ref_vector body(m);
proof_ref pr_i(m), pr0(m); proof_ref pr_i(m), pr0(m);
expr_ref_vector prs(m); expr_ref_vector prs(m);
expr_ref head(y, m); expr_ref head(m);
body.push_back(x); if (m.is_implies(f, x, y)) {
head = y;
body.push_back(x);
}
else if (m.is_or(f)) {
auto a = to_app(f);
for (auto arg : *to_app(f)) {
if (m.is_eq(arg)) {
if (head)
return;
head = arg;
}
else
body.push_back(arg);
}
if (!head)
return;
}
else
return;
flatten_and(body); flatten_and(body);
unsigned j = 0; unsigned j = 0;
flet<bool> _propagate_with_solver(m_propagate_with_solver, true); flet<bool> _propagate_with_solver(m_propagate_with_solver, true);
@ -552,6 +625,39 @@ namespace euf {
} }
} }
void completion::propagate_closures() {
for (auto [q, clos] : m_closures) {
expr* body = clos.second;
auto n = m_egraph.find(body);
SASSERT(n);
#if 0
verbose_stream() << "class of " << mk_pp(body, m) << "\n";
for (auto s : euf::enode_class(n)) {
verbose_stream() << mk_pp(s->get_expr(), m) << "\n";
}
#endif
if (n->is_root())
continue;
auto qn = m_egraph.find(q);
#if 0
verbose_stream() << "class of " << mk_pp(q, m) << "\n";
for (auto s : euf::enode_class(qn)) {
verbose_stream() << mk_pp(s->get_expr(), m) << "\n";
}
#endif
expr_ref new_body = expr_ref(n->get_root()->get_expr(), m);
expr_ref new_q = expr_abstract(m, clos.first, new_body);
new_q = m.update_quantifier(q, new_q);
auto new_qn = m_egraph.find(new_q);
if (!new_qn)
new_qn = m_egraph.mk(new_q, qn->generation(), 0, nullptr);
if (new_qn->get_root() == qn->get_root())
continue;
m_egraph.merge(new_qn, qn, nullptr); // todo track dependencies
m_should_propagate = true;
}
}
binding* completion::tmp_binding(quantifier* q, app* pat, euf::enode* const* _binding) { binding* completion::tmp_binding(quantifier* q, app* pat, euf::enode* const* _binding) {
if (q->get_num_decls() > m_tmp_binding_capacity) { if (q->get_num_decls() > m_tmp_binding_capacity) {
void* mem = memory::allocate(sizeof(binding) + q->get_num_decls() * sizeof(euf::enode*)); void* mem = memory::allocate(sizeof(binding) + q->get_num_decls() * sizeof(euf::enode*));
@ -643,11 +749,12 @@ namespace euf {
void completion::apply_binding(binding& b, quantifier* q, expr_ref_vector const& s) { void completion::apply_binding(binding& b, quantifier* q, expr_ref_vector const& s) {
var_subst subst(m); var_subst subst(m);
expr_ref r = subst(q->get_expr(), s); expr_ref r = subst(q->get_expr(), s);
scoped_generation sg(*this, b.m_max_top_generation + 1); scoped_generation sg(*this, b.m_max_top_generation + 1);
auto [pr, d] = get_dependency(q); auto [pr, d] = get_dependency(q);
if (pr) if (pr)
pr = m.mk_quant_inst(m.mk_or(m.mk_not(q), r), s.size(), s.data()); pr = m.mk_quant_inst(m.mk_or(m.mk_not(q), r), s.size(), s.data());
m_consequences.push_back(r);
add_constraint(r, pr, d); add_constraint(r, pr, d);
propagate_rules(); propagate_rules();
m_egraph.propagate(); m_egraph.propagate();
@ -788,7 +895,7 @@ namespace euf {
if (x1 == y1) if (x1 == y1)
r = expr_ref(m.mk_true(), m); r = expr_ref(m.mk_true(), m);
else { else {
expr* c = get_canonical(x, pr3, d); auto c = get_canonical(x, pr3, d);
if (c == x1) if (c == x1)
r = m_rewriter.mk_eq(y1, c); r = m_rewriter.mk_eq(y1, c);
else if (c == y1) else if (c == y1)
@ -832,8 +939,6 @@ namespace euf {
} }
expr_ref completion::canonize(expr* f, proof_ref& pr, expr_dependency_ref& d) { expr_ref completion::canonize(expr* f, proof_ref& pr, expr_dependency_ref& d) {
if (is_quantifier(f))
return expr_ref(canonize(to_quantifier(f), pr, d), m);
if (!is_app(f)) if (!is_app(f))
return expr_ref(f, m); // todo could normalize ground expressions under quantifiers return expr_ref(f, m); // todo could normalize ground expressions under quantifiers
@ -862,13 +967,30 @@ namespace euf {
return r; return r;
} }
expr_ref completion::canonize(quantifier* q, proof_ref& pr, expr_dependency_ref& d) { expr_ref completion::get_canonical(quantifier* q, proof_ref& pr, expr_dependency_ref& d) {
std::pair<ptr_vector<expr>, expr*> clos; std::pair<ptr_vector<expr>, expr*> clos;
// verbose_stream() << "canonize " << mk_pp(q, m) << "\n";
if (!m_closures.find(q, clos)) if (!m_closures.find(q, clos))
return expr_ref(q, m); return expr_ref(q, m);
expr* body = clos.second; expr* body = clos.second;
expr_ref new_body = canonize(body, pr, d); auto n = m_egraph.find(body);
SASSERT(n);
#if 0
verbose_stream() << "class of " << mk_pp(body, m) << "\n";
for (auto s : euf::enode_class(n)) {
verbose_stream() << mk_pp(s->get_expr(), m) << "\n";
}
#endif
n = m_egraph.find(q);
#if 0
verbose_stream() << "class of " << mk_pp(q, m) << "\n";
for (auto s : euf::enode_class(n)) {
verbose_stream() << mk_pp(s->get_expr(), m) << "\n";
}
#endif
expr_ref new_body = get_canonical(body, pr, d);
expr_ref result = expr_abstract(m, clos.first, new_body); expr_ref result = expr_abstract(m, clos.first, new_body);
result = m.update_quantifier(q, result);
if (m.proofs_enabled()) { if (m.proofs_enabled()) {
// add proof rule // add proof rule
// //
@ -881,10 +1003,28 @@ namespace euf {
} }
expr* completion::get_canonical(expr* f, proof_ref& pr, expr_dependency_ref& d) { expr_ref completion::get_canonical(expr* f, proof_ref& pr, expr_dependency_ref& d) {
expr_ref e(m);
if (has_quantifiers(f)) {
if (is_quantifier(f))
return get_canonical(to_quantifier(f), pr, d);
else if (is_app(f)) {
expr_ref_vector args(m);
for (auto arg : *to_app(f)) {
// TODO: pr reconstruction
args.push_back(get_canonical(arg, pr, d));
}
e = m.mk_app(to_app(f)->get_decl(), args);
if (!m_egraph.find(e))
return e;
f = e;
}
else
UNREACHABLE();
}
enode* n = m_egraph.find(f); enode* n = m_egraph.find(f);
if (!n) verbose_stream() << "not found " << f->get_id() << " " << mk_pp(f, m) << "\n";
enode* r = n->get_root(); enode* r = n->get_root();
d = m.mk_join(d, explain_eq(n, r)); d = m.mk_join(d, explain_eq(n, r));
d = m.mk_join(d, m_deps.get(r->get_id(), nullptr)); d = m.mk_join(d, m_deps.get(r->get_id(), nullptr));
@ -894,7 +1034,7 @@ namespace euf {
pr = m.mk_transitivity(pr, get_canonical_proof(r)); pr = m.mk_transitivity(pr, get_canonical_proof(r));
} }
SASSERT(m_canonical.get(r->get_id())); SASSERT(m_canonical.get(r->get_id()));
return m_canonical.get(r->get_id()); return expr_ref(m_canonical.get(r->get_id()), m);
} }
expr* completion::get_canonical(enode* n) { expr* completion::get_canonical(enode* n) {
@ -990,6 +1130,7 @@ namespace euf {
void completion::collect_statistics(statistics& st) const { void completion::collect_statistics(statistics& st) const {
st.update("euf-completion-rewrites", m_stats.m_num_rewrites); st.update("euf-completion-rewrites", m_stats.m_num_rewrites);
st.update("euf-completion-instances", m_stats.m_num_instances); st.update("euf-completion-instances", m_stats.m_num_instances);
m_egraph.collect_statistics(st);
} }
bool completion::is_gt(expr* lhs, expr* rhs) const { bool completion::is_gt(expr* lhs, expr* rhs) const {
@ -1098,8 +1239,8 @@ namespace euf {
proof_ref pr(m); proof_ref pr(m);
prs.reset(); prs.reset();
for (enode* arg : enode_args(rep)) { for (enode* arg : enode_args(rep)) {
enode* rarg = arg->get_root(); auto rarg = arg->get_root();
expr* c = get_canonical(rarg); auto c = get_canonical(rarg);
if (c) { if (c) {
m_eargs.push_back(c); m_eargs.push_back(c);
new_arg |= c != arg->get_expr(); new_arg |= c != arg->get_expr();

View file

@ -128,7 +128,7 @@ namespace euf {
enode* m_tt, *m_ff; enode* m_tt, *m_ff;
ptr_vector<expr> m_todo; ptr_vector<expr> m_todo;
enode_vector m_args, m_reps, m_nodes_to_canonize; enode_vector m_args, m_reps, m_nodes_to_canonize;
expr_ref_vector m_canonical, m_eargs; expr_ref_vector m_canonical, m_eargs, m_expr_trail, m_consequences;
proof_ref_vector m_canonical_proofs; proof_ref_vector m_canonical_proofs;
// pattern_inference_rw m_infer_patterns; // pattern_inference_rw m_infer_patterns;
bindings m_bindings; bindings m_bindings;
@ -166,11 +166,14 @@ namespace euf {
void read_egraph(); void read_egraph();
expr_ref canonize(expr* f, proof_ref& pr, expr_dependency_ref& dep); expr_ref canonize(expr* f, proof_ref& pr, expr_dependency_ref& dep);
expr_ref canonize_fml(expr* f, proof_ref& pr, expr_dependency_ref& dep); expr_ref canonize_fml(expr* f, proof_ref& pr, expr_dependency_ref& dep);
expr* get_canonical(expr* f, proof_ref& pr, expr_dependency_ref& d); expr_ref get_canonical(expr* f, proof_ref& pr, expr_dependency_ref& d);
expr* get_canonical(enode* n); expr* get_canonical(enode* n);
proof* get_canonical_proof(enode* n); proof* get_canonical_proof(enode* n);
void set_canonical(enode* n, expr* e, proof* pr); void set_canonical(enode* n, expr* e, proof* pr);
void add_constraint(expr*f, proof* pr, expr_dependency* d); void add_constraint(expr*f, proof* pr, expr_dependency* d);
void map_congruences();
void map_congruence(expr* t);
void add_consequence(expr* t);
// Enable equality propagation inside of quantifiers // Enable equality propagation inside of quantifiers
// add quantifier bodies as closure terms to the E-graph. // add quantifier bodies as closure terms to the E-graph.
@ -181,7 +184,7 @@ namespace euf {
// Closure terms are re-abstracted by the canonizer. // Closure terms are re-abstracted by the canonizer.
void add_quantifiers(ptr_vector<expr>& bound, expr* t); void add_quantifiers(ptr_vector<expr>& bound, expr* t);
void add_quantifiers(expr* t); void add_quantifiers(expr* t);
expr_ref canonize(quantifier* q, proof_ref& pr, expr_dependency_ref& d); expr_ref get_canonical(quantifier* q, proof_ref& pr, expr_dependency_ref& d);
obj_map<quantifier, std::pair<ptr_vector<expr>, expr*>> m_closures; obj_map<quantifier, std::pair<ptr_vector<expr>, expr*>> m_closures;
expr_dependency* explain_eq(enode* a, enode* b); expr_dependency* explain_eq(enode* a, enode* b);
@ -208,6 +211,7 @@ namespace euf {
void propagate_rule(conditional_rule& r); void propagate_rule(conditional_rule& r);
void propagate_rules(); void propagate_rules();
void propagate_all_rules(); void propagate_all_rules();
void propagate_closures();
void clear_propagation_queue(); void clear_propagation_queue();
ptr_vector<conditional_rule> m_propagation_queue; ptr_vector<conditional_rule> m_propagation_queue;
struct push_watch_rule; struct push_watch_rule;

View file

@ -572,7 +572,7 @@ namespace arith {
} }
void solver::new_diseq_eh(euf::th_eq const& e) { void solver::new_diseq_eh(euf::th_eq const& e) {
TRACE(artih, tout << mk_bounded_pp(e.eq(), m) << "\n"); TRACE(artih, tout << mk_bounded_pp(e.eq()->get_expr(), m) << "\n");
ensure_column(e.v1()); ensure_column(e.v1());
ensure_column(e.v2()); ensure_column(e.v2());
m_delayed_eqs.push_back(std::make_pair(e, false)); m_delayed_eqs.push_back(std::make_pair(e, false));

View file

@ -1147,7 +1147,7 @@ namespace arith {
new_eq_eh(e); new_eq_eh(e);
else if (is_eq(e.v1(), e.v2())) { else if (is_eq(e.v1(), e.v2())) {
mk_diseq_axiom(e.v1(), e.v2()); mk_diseq_axiom(e.v1(), e.v2());
TRACE(arith, tout << mk_bounded_pp(e.eq(), m) << " " << use_nra_model() << "\n"); TRACE(arith, tout << mk_bounded_pp(e.eq()->get_expr(), m) << " " << use_nra_model() << "\n");
found_diseq = true; found_diseq = true;
break; break;
} }

View file

@ -280,7 +280,7 @@ namespace bv {
undef_idx--; undef_idx--;
sat::literal consequent = m_bits[v1][undef_idx]; sat::literal consequent = m_bits[v1][undef_idx];
sat::literal b = m_bits[v2][undef_idx]; sat::literal b = m_bits[v2][undef_idx];
sat::literal antecedent = ~expr2literal(ne.eq()); sat::literal antecedent = ~expr2literal(ne.eq()->get_expr());
SASSERT(s().value(antecedent) == l_true); SASSERT(s().value(antecedent) == l_true);
SASSERT(s().value(consequent) == l_undef); SASSERT(s().value(consequent) == l_undef);
SASSERT(s().value(b) != l_undef); SASSERT(s().value(b) != l_undef);

View file

@ -758,6 +758,7 @@ X(Global, pivot_bug, "pivot bug")
X(Global, pivot_shape, "pivot shape") X(Global, pivot_shape, "pivot shape")
X(Global, pivot_stats, "pivot stats") X(Global, pivot_stats, "pivot stats")
X(Global, plugin, "plugin") X(Global, plugin, "plugin")
X(Global, plugin_verbose, "plugin verbose")
X(Global, pob_queue, "pob queue") X(Global, pob_queue, "pob queue")
X(Global, poly_rewriter, "poly rewriter") X(Global, poly_rewriter, "poly rewriter")
X(Global, polynomial, "polynomial") X(Global, polynomial, "polynomial")