3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-08-02 01:13:18 +00:00

wip - throttle AC completion, enable congruences over bound bodies

- AC completion which is exposed as an option to the new congruence closure core used roots of E-Graph which gets ordering of monomials out of sync.
- Added injective function handling to AC completion
- Move to model where all equations, also unit to unit are in completion
- throw in first level bound bodies into the E-graph to enable canonization on them.
This commit is contained in:
Nikolaj Bjorner 2025-07-11 12:48:27 +02:00
parent 35b1d09425
commit 0995928f6e
7 changed files with 345 additions and 76 deletions

View file

@ -53,6 +53,14 @@ More notes:
after merge
n' + k = j + k, could be that n' + k < j + k < n + k in term ordering because n' < j, m < n
We filter results of superposition so that the size of monomials in the new equations don't grow.
This filter is used to make the process somewhat tractable. In other words, we never compute a
complete saturation. If l1 = r1 l2 = r2 are used to produce new equation l = r, we ensure
that min(|l|,|r|) <= min(|r1|,|r2|) and max(|l|,|r|) <= max(|l1|,|l2|)
If the operator is injective we simplify equations
xl = xr to l = r
TODOs:
- Efficiency of handling shared terms.
@ -63,6 +71,7 @@ TODOs:
- by an epoch counter that can be updated by the egraph class whenever there is a push/pop.
- store the epoch as a tick on equations and possibly when updating monomials on equations.
--*/
#include "ast/euf/euf_ac_plugin.h"
@ -195,6 +204,20 @@ namespace euf {
return out;
}
std::ostream& ac_plugin::display_monomial_ll(std::ostream& out, ptr_vector<node> const& m) const {
for (auto n : m)
out << n->id() << " ";
return out;
}
std::ostream& ac_plugin::display_equation_ll(std::ostream& out, eq const& e) const {
display_status(out, e.status) << " ";
display_monomial_ll(out, monomial(e.l));
out << "== ";
display_monomial_ll(out, monomial(e.r));
return out;
}
std::ostream& ac_plugin::display_status(std::ostream& out, eq_status s) const {
switch (s) {
case eq_status::is_dead: out << "d"; break;
@ -207,15 +230,14 @@ namespace euf {
std::ostream& ac_plugin::display(std::ostream& out) const {
unsigned i = 0;
for (auto const& eq : m_eqs) {
out << i << ": " << eq.l << " == " << eq.r << ": ";
display_equation(out, eq);
out << "\n";
if (eq.status != eq_status::is_dead)
out << i << ": " << eq_pp_ll(*this, eq) << "\n";
++i;
}
i = 0;
for (auto m : m_monomials) {
out << i << ": ";
display_monomial(out, m);
display_monomial_ll(out, m);
out << "\n";
++i;
}
@ -224,7 +246,7 @@ namespace euf {
continue;
if (n->eqs.empty() && n->shared.empty())
continue;
out << g.bpp(n->n) << " r: " << n->root_id() << " ";
out << g.bpp(n->n) << " r: " << n->id() << " ";
if (!n->eqs.empty()) {
out << "eqs ";
for (auto l : n->eqs)
@ -247,8 +269,7 @@ namespace euf {
TRACE(plugin, tout << g.bpp(l) << " == " << g.bpp(r) << " " << is_op(l) << " " << is_op(r) << "\n");
if (!is_op(l) && !is_op(r))
merge(mk_node(l), mk_node(r), j);
else
init_equation(eq(to_monomial(l), to_monomial(r), j));
init_equation(eq(to_monomial(l), to_monomial(r), j));
}
void ac_plugin::diseq_eh(enode* eq) {
@ -264,12 +285,20 @@ namespace euf {
void ac_plugin::init_equation(eq const& e) {
m_eqs.push_back(e);
auto& eq = m_eqs.back();
TRACE(plugin, display_equation(tout, e) << "\n");
TRACE(plugin, display_equation_ll(tout, e) << "\n");
deduplicate(monomial(eq.l).m_nodes, monomial(eq.r).m_nodes);
TRACE(plugin, display_equation_ll(tout << "dedup ", e) << "\n");
if (orient_equation(eq)) {
auto& ml = monomial(eq.l);
auto& mr = monomial(eq.r);
if (ml.size() == 1 && mr.size() == 1)
push_merge(ml[0]->n, mr[0]->n, eq.j);
unsigned eq_id = m_eqs.size() - 1;
for (auto n : monomial(eq.l)) {
for (auto n : ml) {
if (!n->root->n->is_marked1()) {
n->root->eqs.push_back(eq_id);
n->root->n->mark1();
@ -280,7 +309,7 @@ namespace euf {
}
}
for (auto n : monomial(eq.r)) {
for (auto n : mr) {
if (!n->root->n->is_marked1()) {
n->root->eqs.push_back(eq_id);
n->root->n->mark1();
@ -291,13 +320,13 @@ namespace euf {
}
}
for (auto n : monomial(eq.l))
for (auto n : ml)
n->root->n->unmark1();
for (auto n : monomial(eq.r))
for (auto n : mr)
n->root->n->unmark1();
TRACE(plugin, display_equation(tout, e) << "\n");
TRACE(plugin, display_equation_ll(tout, e) << "\n");
m_to_simplify_todo.insert(eq_id);
}
else
@ -312,14 +341,14 @@ namespace euf {
if (ml.size() < mr.size()) {
std::swap(e.l, e.r);
return true;
}
}
else {
sort(ml);
sort(mr);
for (unsigned i = ml.size(); i-- > 0;) {
if (ml[i]->root_id() == mr[i]->root_id())
for (unsigned i = 0; i < ml.size(); ++i) {
if (ml[i]->id() == mr[i]->id())
continue;
if (ml[i]->root_id() < mr[i]->root_id())
if (ml[i]->id() < mr[i]->id())
std::swap(e.l, e.r);
return true;
}
@ -327,15 +356,38 @@ namespace euf {
}
}
bool ac_plugin::is_equation_oriented(eq const& e) const {
auto& ml = monomial(e.l);
auto& mr = monomial(e.r);
if (ml.size() > mr.size())
return true;
if (ml.size() < mr.size())
return false;
else {
if (!is_sorted(ml))
return false;
if (!is_sorted(mr))
return false;
for (unsigned i = 0; i < ml.size(); ++i) {
if (ml[i]->id() == mr[i]->id())
continue;
if (ml[i]->id() < mr[i]->id())
return false;
return true;
}
return false;
}
}
void ac_plugin::sort(monomial_t& m) {
std::sort(m.begin(), m.end(), [&](node* a, node* b) { return a->root_id() < b->root_id(); });
std::sort(m.begin(), m.end(), [&](node* a, node* b) { return a->id() < b->id(); });
}
bool ac_plugin::is_sorted(monomial_t const& m) const {
if (m.m_bloom.m_tick == m_tick)
return true;
for (unsigned i = m.size(); i-- > 1; )
if (m[i - 1]->root_id() > m[i]->root_id())
if (m[i - 1]->id() > m[i]->id())
return false;
return true;
}
@ -346,7 +398,7 @@ namespace euf {
return bloom.m_filter;
bloom.m_filter = 0;
for (auto n : m)
bloom.m_filter |= (1ull << (n->root_id() % 64ull));
bloom.m_filter |= (1ull << (n->id() % 64ull));
if (!is_sorted(m))
sort(m);
bloom.m_tick = m_tick;
@ -367,25 +419,29 @@ namespace euf {
if (bloom.m_tick != m_tick) {
bloom.m_filter = 0;
for (auto n : m)
bloom.m_filter |= (1ull << (n->root_id() % 64ull));
bloom.m_filter |= (1ull << (n->id() % 64ull));
bloom.m_tick = m_tick;
}
auto f2 = bloom.m_filter;
return (filter(subset) | f2) == f2;
}
void ac_plugin::merge(node* root, node* other, justification j) {
TRACE(plugin, tout << root << " == " << other << " num shared " << other->shared.size() << "\n");
for (auto n : equiv(other))
n->root = root;
m_merge_trail.push_back({ other, root->shared.size(), root->eqs.size() });
for (auto eq_id : other->eqs)
void ac_plugin::merge(node* a, node* b, justification j) {
TRACE(plugin, tout << a << " == " << b << " num shared " << b->shared.size() << "\n");
if (a == b)
return;
if (a->id() < b->id())
std::swap(a, b);
for (auto n : equiv(a))
n->root = b;
m_merge_trail.push_back({ a, b->shared.size(), b->eqs.size() });
for (auto eq_id : a->eqs)
set_status(eq_id, eq_status::to_simplify);
for (auto m : other->shared)
for (auto m : a->shared)
m_shared_todo.insert(m);
root->shared.append(other->shared);
root->eqs.append(other->eqs);
std::swap(root->next, other->next);
b->shared.append(a->shared);
b->eqs.append(a->eqs);
std::swap(b->next, a->next);
push_undo(is_merge_node);
++m_tick;
}
@ -427,7 +483,7 @@ namespace euf {
args.push_back(arg->root->n->get_expr());
}
auto n = m.mk_app(m_fid, m_op, args.size(), args.data());
return g.mk(n, 0, nodes.size(), nodes.data());
return g.find(n) ? g.find(n) : g.mk(n, 0, nodes.size(), nodes.data());
}
ac_plugin::node* ac_plugin::node::mk(region& r, enode* n) {
@ -460,7 +516,7 @@ namespace euf {
if (eq_id == UINT_MAX)
break;
TRACE(plugin, tout << "propagate " << eq_id << ": " << eq_pp(*this, m_eqs[eq_id]) << "\n");
TRACE(plugin, tout << "propagate " << eq_id << ": " << eq_pp_ll(*this, m_eqs[eq_id]) << "\n");
// simplify eq using processed
TRACE(plugin,
@ -470,6 +526,12 @@ namespace euf {
if (is_processed(other_eq) && backward_simplify(eq_id, other_eq))
goto loop_start;
auto& eq = m_eqs[eq_id];
deduplicate(monomial(eq.l).m_nodes, monomial(eq.r).m_nodes);
if (monomial(eq.l).size() == 0) {
set_status(eq_id, eq_status::is_dead);
continue;
}
set_status(eq_id, eq_status::processed);
// simplify processed using eq
@ -478,9 +540,13 @@ namespace euf {
forward_simplify(eq_id, other_eq);
// superpose, create new equations
unsigned new_eqs = 0;
for (auto other_eq : superpose_iterator(eq_id))
if (is_processed(other_eq))
superpose(eq_id, other_eq);
new_eqs += superpose(eq_id, other_eq);
(void)new_eqs;
TRACE(plugin, tout << "added eqs " << new_eqs << "\n");
// simplify to_simplify using eq
for (auto other_eq : forward_iterator(eq_id))
@ -633,7 +699,7 @@ namespace euf {
void ac_plugin::init_ref_counts(ptr_vector<node> const& monomial, ref_counts& counts) const {
counts.reset();
for (auto n : monomial)
counts.inc(n->root_id(), 1);
counts.inc(n->id(), 1);
}
bool ac_plugin::is_correct_ref_count(monomial_t const& m, ref_counts const& counts) const {
@ -660,7 +726,7 @@ namespace euf {
auto& src = m_eqs[src_eq]; // src_r_counts, src_l_counts are initialized
auto& dst = m_eqs[dst_eq];
TRACE(plugin, tout << "forward simplify " << eq_pp(*this, src) << " " << eq_pp(*this, dst) << "\n");
TRACE(plugin, tout << "forward simplify " << eq_pp_ll(*this, src) << " " << eq_pp_ll(*this, dst) << "\n");
if (forward_subsumes(src_eq, dst_eq)) {
@ -685,7 +751,7 @@ namespace euf {
// := src_rhs + elements from dst_rhs that are in excess of src_lhs
unsigned num_overlap = 0;
for (auto n : monomial(dst.r)) {
unsigned id = n->root_id();
unsigned id = n->id();
unsigned dst_count = m_dst_r_counts[id];
unsigned src_count = m_src_l_counts[id];
if (dst_count > src_count) {
@ -726,7 +792,7 @@ namespace euf {
//
// dst_ids, dst_count contain rhs of dst_eq
//
TRACE(plugin, tout << "backward simplify " << eq_pp(*this, src) << " " << eq_pp(*this, dst) << " can-be-subset: " << can_be_subset(monomial(src.l), monomial(dst.r)) << "\n");
TRACE(plugin, tout << "backward simplify " << eq_pp_ll(*this, src) << " " << eq_pp_ll(*this, dst) << " can-be-subset: " << can_be_subset(monomial(src.l), monomial(dst.r)) << "\n");
if (backward_subsumes(src_eq, dst_eq)) {
TRACE(plugin, tout << "backward subsumed\n");
@ -782,7 +848,7 @@ namespace euf {
SASSERT(is_correct_ref_count(monomial(src.r), m_src_r_counts));
// add difference betwen dst.l and src.l to both src.l, src.r
for (auto n : monomial(dst.l)) {
unsigned id = n->root_id();
unsigned id = n->id();
SASSERT(m_dst_l_counts[id] >= m_src_l_counts[id]);
unsigned diff = m_dst_l_counts[id] - m_src_l_counts[id];
if (diff > 0) {
@ -792,7 +858,7 @@ namespace euf {
}
// now dst.r and src.r should align and have the same elements.
// since src.r is a subset of dst.r we iterate over dst.r
return all_of(monomial(dst.r), [&](node* n) { unsigned id = n->root_id(); return m_src_r_counts[id] == m_dst_r_counts[id]; });
return all_of(monomial(dst.r), [&](node* n) { unsigned id = n->id(); return m_src_r_counts[id] == m_dst_r_counts[id]; });
}
// src_l_counts, src_r_counts are initialized for src.l, src.r
@ -815,7 +881,7 @@ namespace euf {
SASSERT(is_correct_ref_count(monomial(dst.l), m_dst_l_counts));
SASSERT(is_correct_ref_count(monomial(dst.r), m_dst_r_counts));
for (auto n : monomial(src.l)) {
unsigned id = n->root_id();
unsigned id = n->id();
SASSERT(m_src_l_counts[id] <= m_dst_l_counts[id]);
unsigned diff = m_dst_l_counts[id] - m_src_l_counts[id];
if (diff == 0)
@ -826,7 +892,7 @@ namespace euf {
m_dst_r_counts.dec(id, diff);
}
return all_of(monomial(dst.r), [&](node* n) { unsigned id = n->root_id(); return m_src_r_counts[id] == m_dst_r_counts[id]; });
return all_of(monomial(dst.r), [&](node* n) { unsigned id = n->id(); return m_src_r_counts[id] == m_dst_r_counts[id]; });
}
void ac_plugin::rewrite1(ref_counts const& src_l, monomial_t const& src_r, ref_counts& dst_counts, ptr_vector<node>& dst) {
@ -837,14 +903,17 @@ namespace euf {
SASSERT(is_correct_ref_count(dst, dst_counts));
SASSERT(&src_r.m_nodes != &dst);
unsigned sz = dst.size(), j = 0;
bool change = false;
for (unsigned i = 0; i < sz; ++i) {
auto* n = dst[i];
unsigned id = n->root_id();
unsigned id = n->id();
unsigned dst_count = dst_counts[id];
unsigned src_count = src_l[id];
SASSERT(dst_count > 0);
if (src_count == 0)
dst[j++] = n;
if (src_count == 0) {
dst[j++] = n;
}
else if (src_count < dst_count) {
dst[j++] = n;
dst_counts.dec(id, 1);
@ -857,25 +926,41 @@ namespace euf {
// rewrite monomial to normal form.
bool ac_plugin::reduce(ptr_vector<node>& m, justification& j) {
bool change = false;
unsigned sz = m.size();
unsigned jj = 0;
//verbose_stream() << "start\n";
do {
init_loop:
//verbose_stream() << "loop " << jj++ << "\n";
if (m.size() == 1)
return change;
bloom b;
init_ref_counts(m, m_m_counts);
unsigned k = 0;
for (auto n : m) {
//verbose_stream() << "inner loop " << k++ << "\n";
for (auto eq : n->root->eqs) {
if (!is_processed(eq))
continue;
auto& src = m_eqs[eq];
if (!is_equation_oriented(src)) {
continue;
if (!orient_equation(src))
continue;
// deduplicate(src.l, src.r);
}
if (!can_be_subset(monomial(src.l), m, b))
continue;
if (!is_subset(m_m_counts, m_eq_counts, monomial(src.l)))
continue;
TRACE(plugin, display_equation(tout << "reduce ", src) << "\n");
TRACE(plugin, display_equation_ll(tout << "reduce ", src) << "\n");
SASSERT(is_correct_ref_count(monomial(src.l), m_eq_counts));
//display_equation_ll(std::cout << "reduce ", src) << ": ";
//display_monomial_ll(std::cout, m);
rewrite1(m_eq_counts, monomial(src.r), m_m_counts, m);
//display_monomial_ll(std::cout << " -> ", m) << "\n";
j = join(j, eq);
change = true;
goto init_loop;
@ -883,6 +968,7 @@ namespace euf {
}
}
while (false);
VERIFY(sz >= m.size());
return change;
}
@ -917,13 +1003,17 @@ namespace euf {
}
void ac_plugin::superpose(unsigned src_eq, unsigned dst_eq) {
bool ac_plugin::superpose(unsigned src_eq, unsigned dst_eq) {
if (src_eq == dst_eq)
return;
return false;
auto& src = m_eqs[src_eq];
auto& dst = m_eqs[dst_eq];
TRACE(plugin, tout << "superpose: "; display_equation(tout, src); tout << " "; display_equation(tout, dst); tout << "\n";);
unsigned max_left = std::max(monomial(src.l).size(), monomial(dst.l).size());
unsigned min_right = std::max(monomial(src.r).size(), monomial(dst.r).size());
TRACE(plugin, tout << "superpose: "; display_equation_ll(tout, src); tout << " "; display_equation_ll(tout, dst); tout << "\n";);
// AB -> C, AD -> E => BE ~ CD
// m_src_ids, m_src_counts contains information about src (call it AD -> E)
m_dst_l_counts.reset();
@ -938,7 +1028,7 @@ namespace euf {
// compute BE, initialize dst_ids, dst_counts
bool overlap = false;
for (auto n : monomial(dst.l)) {
unsigned id = n->root_id();
unsigned id = n->id();
m_dst_l_counts.inc(id, 1);
if (m_src_l_counts[id] < m_dst_l_counts[id])
m_src_r.push_back(n);
@ -947,12 +1037,12 @@ namespace euf {
if (!overlap) {
m_src_r.shrink(src_r_size);
return;
return false;
}
// compute CD
for (auto n : monomial(src.l)) {
unsigned id = n->root_id();
unsigned id = n->id();
if (m_dst_l_counts[id] > 0)
m_dst_l_counts.dec(id, 1);
else
@ -961,21 +1051,28 @@ namespace euf {
if (are_equal(m_src_r, m_dst_r)) {
m_src_r.shrink(src_r_size);
return;
return false;
}
TRACE(plugin, tout << m_pp(*this, m_src_r) << "== " << m_pp(*this, m_dst_r) << "\n";);
TRACE(plugin, tout << "superpose result: " << m_pp_ll(*this, m_src_r) << "== " << m_pp_ll(*this, m_dst_r) << "\n";);
justification j = justify_rewrite(src_eq, dst_eq);
deduplicate(m_src_r, m_dst_r);
reduce(m_dst_r, j);
reduce(m_src_r, j);
if (m_src_r.size() == 1 && m_dst_r.size() == 1)
push_merge(m_src_r[0]->n, m_dst_r[0]->n, j);
else
TRACE(plugin, tout << "superpose result: " << m_pp_ll(*this, m_src_r) << "== " << m_pp_ll(*this, m_dst_r) << "\n";);
bool added_eq = false;
unsigned max_left_new = std::max(m_src_r.size(), m_dst_r.size());
unsigned min_right_new = std::min(m_src_r.size(), m_dst_r.size());
if (max_left_new <= max_left && min_right_new <= min_right) {
init_equation(eq(to_monomial(m_src_r), to_monomial(m_dst_r), j));
added_eq = true;
}
m_src_r.reset();
m_src_r.append(monomial(src.r).m_nodes);
return added_eq;
}
bool ac_plugin::are_equal(monomial_t& a, monomial_t& b) {
@ -987,10 +1084,10 @@ namespace euf {
return false;
m_eq_counts.reset();
for (auto n : a)
m_eq_counts.inc(n->root_id(), 1);
m_eq_counts.inc(n->id(), 1);
for (auto n : b) {
unsigned id = n->root_id();
unsigned id = n->id();
if (m_eq_counts[id] == 0)
return false;
m_eq_counts.dec(id, 1);
@ -998,6 +1095,38 @@ namespace euf {
return true;
}
void ac_plugin::deduplicate(ptr_vector<node>& a, ptr_vector<node>& b) {
if (!m_is_injective)
return;
m_eq_counts.reset();
for (auto n : a)
m_eq_counts.inc(n->id(), 1);
bool has_dup = any_of(b, [&](node* n) { return m_eq_counts[n->id()] > 0; });
if (!has_dup)
return;
std::sort(a.begin(), a.end(), [&](node* x, node* y) { return x->id() < y->id(); });
std::sort(b.begin(), b.end(), [&](node* x, node* y) { return x->id() < y->id(); });
unsigned i = 0, j = 0, in = 0, jn = 0;
for (; i < a.size() && j < b.size(); ) {
if (a[i]->id() == b[j]->id()) {
++i;
++j;
}
else if (a[i]->id() < b[j]->id()) {
a[in++] = a[i++];
}
else {
b[jn++] = b[j++];
}
}
for (; i < a.size(); ++i)
a[in++] = a[i];
for (; j < b.size(); ++j)
b[jn++] = b[j];
a.shrink(in);
b.shrink(jn);
}
//
// simple version based on propagating all shared
// todo: version touching only newly processed shared, and maintaining incremental data-structures.