3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-07-29 15:37:58 +00:00

wip - throttle AC completion, enable congruences over bound bodies

- AC completion which is exposed as an option to the new congruence closure core used roots of E-Graph which gets ordering of monomials out of sync.
- Added injective function handling to AC completion
- Move to model where all equations, also unit to unit are in completion
- throw in first level bound bodies into the E-graph to enable canonization on them.
This commit is contained in:
Nikolaj Bjorner 2025-07-11 12:48:27 +02:00
parent 35b1d09425
commit 0995928f6e
7 changed files with 345 additions and 76 deletions

View file

@ -53,6 +53,14 @@ More notes:
after merge
n' + k = j + k, could be that n' + k < j + k < n + k in term ordering because n' < j, m < n
We filter results of superposition so that the size of monomials in the new equations don't grow.
This filter is used to make the process somewhat tractable. In other words, we never compute a
complete saturation. If l1 = r1 l2 = r2 are used to produce new equation l = r, we ensure
that min(|l|,|r|) <= min(|r1|,|r2|) and max(|l|,|r|) <= max(|l1|,|l2|)
If the operator is injective we simplify equations
xl = xr to l = r
TODOs:
- Efficiency of handling shared terms.
@ -63,6 +71,7 @@ TODOs:
- by an epoch counter that can be updated by the egraph class whenever there is a push/pop.
- store the epoch as a tick on equations and possibly when updating monomials on equations.
--*/
#include "ast/euf/euf_ac_plugin.h"
@ -195,6 +204,20 @@ namespace euf {
return out;
}
std::ostream& ac_plugin::display_monomial_ll(std::ostream& out, ptr_vector<node> const& m) const {
for (auto n : m)
out << n->id() << " ";
return out;
}
std::ostream& ac_plugin::display_equation_ll(std::ostream& out, eq const& e) const {
display_status(out, e.status) << " ";
display_monomial_ll(out, monomial(e.l));
out << "== ";
display_monomial_ll(out, monomial(e.r));
return out;
}
std::ostream& ac_plugin::display_status(std::ostream& out, eq_status s) const {
switch (s) {
case eq_status::is_dead: out << "d"; break;
@ -207,15 +230,14 @@ namespace euf {
std::ostream& ac_plugin::display(std::ostream& out) const {
unsigned i = 0;
for (auto const& eq : m_eqs) {
out << i << ": " << eq.l << " == " << eq.r << ": ";
display_equation(out, eq);
out << "\n";
if (eq.status != eq_status::is_dead)
out << i << ": " << eq_pp_ll(*this, eq) << "\n";
++i;
}
i = 0;
for (auto m : m_monomials) {
out << i << ": ";
display_monomial(out, m);
display_monomial_ll(out, m);
out << "\n";
++i;
}
@ -224,7 +246,7 @@ namespace euf {
continue;
if (n->eqs.empty() && n->shared.empty())
continue;
out << g.bpp(n->n) << " r: " << n->root_id() << " ";
out << g.bpp(n->n) << " r: " << n->id() << " ";
if (!n->eqs.empty()) {
out << "eqs ";
for (auto l : n->eqs)
@ -247,7 +269,6 @@ namespace euf {
TRACE(plugin, tout << g.bpp(l) << " == " << g.bpp(r) << " " << is_op(l) << " " << is_op(r) << "\n");
if (!is_op(l) && !is_op(r))
merge(mk_node(l), mk_node(r), j);
else
init_equation(eq(to_monomial(l), to_monomial(r), j));
}
@ -264,12 +285,20 @@ namespace euf {
void ac_plugin::init_equation(eq const& e) {
m_eqs.push_back(e);
auto& eq = m_eqs.back();
TRACE(plugin, display_equation(tout, e) << "\n");
TRACE(plugin, display_equation_ll(tout, e) << "\n");
deduplicate(monomial(eq.l).m_nodes, monomial(eq.r).m_nodes);
TRACE(plugin, display_equation_ll(tout << "dedup ", e) << "\n");
if (orient_equation(eq)) {
auto& ml = monomial(eq.l);
auto& mr = monomial(eq.r);
if (ml.size() == 1 && mr.size() == 1)
push_merge(ml[0]->n, mr[0]->n, eq.j);
unsigned eq_id = m_eqs.size() - 1;
for (auto n : monomial(eq.l)) {
for (auto n : ml) {
if (!n->root->n->is_marked1()) {
n->root->eqs.push_back(eq_id);
n->root->n->mark1();
@ -280,7 +309,7 @@ namespace euf {
}
}
for (auto n : monomial(eq.r)) {
for (auto n : mr) {
if (!n->root->n->is_marked1()) {
n->root->eqs.push_back(eq_id);
n->root->n->mark1();
@ -291,13 +320,13 @@ namespace euf {
}
}
for (auto n : monomial(eq.l))
for (auto n : ml)
n->root->n->unmark1();
for (auto n : monomial(eq.r))
for (auto n : mr)
n->root->n->unmark1();
TRACE(plugin, display_equation(tout, e) << "\n");
TRACE(plugin, display_equation_ll(tout, e) << "\n");
m_to_simplify_todo.insert(eq_id);
}
else
@ -316,10 +345,10 @@ namespace euf {
else {
sort(ml);
sort(mr);
for (unsigned i = ml.size(); i-- > 0;) {
if (ml[i]->root_id() == mr[i]->root_id())
for (unsigned i = 0; i < ml.size(); ++i) {
if (ml[i]->id() == mr[i]->id())
continue;
if (ml[i]->root_id() < mr[i]->root_id())
if (ml[i]->id() < mr[i]->id())
std::swap(e.l, e.r);
return true;
}
@ -327,15 +356,38 @@ namespace euf {
}
}
bool ac_plugin::is_equation_oriented(eq const& e) const {
auto& ml = monomial(e.l);
auto& mr = monomial(e.r);
if (ml.size() > mr.size())
return true;
if (ml.size() < mr.size())
return false;
else {
if (!is_sorted(ml))
return false;
if (!is_sorted(mr))
return false;
for (unsigned i = 0; i < ml.size(); ++i) {
if (ml[i]->id() == mr[i]->id())
continue;
if (ml[i]->id() < mr[i]->id())
return false;
return true;
}
return false;
}
}
void ac_plugin::sort(monomial_t& m) {
std::sort(m.begin(), m.end(), [&](node* a, node* b) { return a->root_id() < b->root_id(); });
std::sort(m.begin(), m.end(), [&](node* a, node* b) { return a->id() < b->id(); });
}
bool ac_plugin::is_sorted(monomial_t const& m) const {
if (m.m_bloom.m_tick == m_tick)
return true;
for (unsigned i = m.size(); i-- > 1; )
if (m[i - 1]->root_id() > m[i]->root_id())
if (m[i - 1]->id() > m[i]->id())
return false;
return true;
}
@ -346,7 +398,7 @@ namespace euf {
return bloom.m_filter;
bloom.m_filter = 0;
for (auto n : m)
bloom.m_filter |= (1ull << (n->root_id() % 64ull));
bloom.m_filter |= (1ull << (n->id() % 64ull));
if (!is_sorted(m))
sort(m);
bloom.m_tick = m_tick;
@ -367,25 +419,29 @@ namespace euf {
if (bloom.m_tick != m_tick) {
bloom.m_filter = 0;
for (auto n : m)
bloom.m_filter |= (1ull << (n->root_id() % 64ull));
bloom.m_filter |= (1ull << (n->id() % 64ull));
bloom.m_tick = m_tick;
}
auto f2 = bloom.m_filter;
return (filter(subset) | f2) == f2;
}
void ac_plugin::merge(node* root, node* other, justification j) {
TRACE(plugin, tout << root << " == " << other << " num shared " << other->shared.size() << "\n");
for (auto n : equiv(other))
n->root = root;
m_merge_trail.push_back({ other, root->shared.size(), root->eqs.size() });
for (auto eq_id : other->eqs)
void ac_plugin::merge(node* a, node* b, justification j) {
TRACE(plugin, tout << a << " == " << b << " num shared " << b->shared.size() << "\n");
if (a == b)
return;
if (a->id() < b->id())
std::swap(a, b);
for (auto n : equiv(a))
n->root = b;
m_merge_trail.push_back({ a, b->shared.size(), b->eqs.size() });
for (auto eq_id : a->eqs)
set_status(eq_id, eq_status::to_simplify);
for (auto m : other->shared)
for (auto m : a->shared)
m_shared_todo.insert(m);
root->shared.append(other->shared);
root->eqs.append(other->eqs);
std::swap(root->next, other->next);
b->shared.append(a->shared);
b->eqs.append(a->eqs);
std::swap(b->next, a->next);
push_undo(is_merge_node);
++m_tick;
}
@ -427,7 +483,7 @@ namespace euf {
args.push_back(arg->root->n->get_expr());
}
auto n = m.mk_app(m_fid, m_op, args.size(), args.data());
return g.mk(n, 0, nodes.size(), nodes.data());
return g.find(n) ? g.find(n) : g.mk(n, 0, nodes.size(), nodes.data());
}
ac_plugin::node* ac_plugin::node::mk(region& r, enode* n) {
@ -460,7 +516,7 @@ namespace euf {
if (eq_id == UINT_MAX)
break;
TRACE(plugin, tout << "propagate " << eq_id << ": " << eq_pp(*this, m_eqs[eq_id]) << "\n");
TRACE(plugin, tout << "propagate " << eq_id << ": " << eq_pp_ll(*this, m_eqs[eq_id]) << "\n");
// simplify eq using processed
TRACE(plugin,
@ -470,6 +526,12 @@ namespace euf {
if (is_processed(other_eq) && backward_simplify(eq_id, other_eq))
goto loop_start;
auto& eq = m_eqs[eq_id];
deduplicate(monomial(eq.l).m_nodes, monomial(eq.r).m_nodes);
if (monomial(eq.l).size() == 0) {
set_status(eq_id, eq_status::is_dead);
continue;
}
set_status(eq_id, eq_status::processed);
// simplify processed using eq
@ -478,9 +540,13 @@ namespace euf {
forward_simplify(eq_id, other_eq);
// superpose, create new equations
unsigned new_eqs = 0;
for (auto other_eq : superpose_iterator(eq_id))
if (is_processed(other_eq))
superpose(eq_id, other_eq);
new_eqs += superpose(eq_id, other_eq);
(void)new_eqs;
TRACE(plugin, tout << "added eqs " << new_eqs << "\n");
// simplify to_simplify using eq
for (auto other_eq : forward_iterator(eq_id))
@ -633,7 +699,7 @@ namespace euf {
void ac_plugin::init_ref_counts(ptr_vector<node> const& monomial, ref_counts& counts) const {
counts.reset();
for (auto n : monomial)
counts.inc(n->root_id(), 1);
counts.inc(n->id(), 1);
}
bool ac_plugin::is_correct_ref_count(monomial_t const& m, ref_counts const& counts) const {
@ -660,7 +726,7 @@ namespace euf {
auto& src = m_eqs[src_eq]; // src_r_counts, src_l_counts are initialized
auto& dst = m_eqs[dst_eq];
TRACE(plugin, tout << "forward simplify " << eq_pp(*this, src) << " " << eq_pp(*this, dst) << "\n");
TRACE(plugin, tout << "forward simplify " << eq_pp_ll(*this, src) << " " << eq_pp_ll(*this, dst) << "\n");
if (forward_subsumes(src_eq, dst_eq)) {
@ -685,7 +751,7 @@ namespace euf {
// := src_rhs + elements from dst_rhs that are in excess of src_lhs
unsigned num_overlap = 0;
for (auto n : monomial(dst.r)) {
unsigned id = n->root_id();
unsigned id = n->id();
unsigned dst_count = m_dst_r_counts[id];
unsigned src_count = m_src_l_counts[id];
if (dst_count > src_count) {
@ -726,7 +792,7 @@ namespace euf {
//
// dst_ids, dst_count contain rhs of dst_eq
//
TRACE(plugin, tout << "backward simplify " << eq_pp(*this, src) << " " << eq_pp(*this, dst) << " can-be-subset: " << can_be_subset(monomial(src.l), monomial(dst.r)) << "\n");
TRACE(plugin, tout << "backward simplify " << eq_pp_ll(*this, src) << " " << eq_pp_ll(*this, dst) << " can-be-subset: " << can_be_subset(monomial(src.l), monomial(dst.r)) << "\n");
if (backward_subsumes(src_eq, dst_eq)) {
TRACE(plugin, tout << "backward subsumed\n");
@ -782,7 +848,7 @@ namespace euf {
SASSERT(is_correct_ref_count(monomial(src.r), m_src_r_counts));
// add difference betwen dst.l and src.l to both src.l, src.r
for (auto n : monomial(dst.l)) {
unsigned id = n->root_id();
unsigned id = n->id();
SASSERT(m_dst_l_counts[id] >= m_src_l_counts[id]);
unsigned diff = m_dst_l_counts[id] - m_src_l_counts[id];
if (diff > 0) {
@ -792,7 +858,7 @@ namespace euf {
}
// now dst.r and src.r should align and have the same elements.
// since src.r is a subset of dst.r we iterate over dst.r
return all_of(monomial(dst.r), [&](node* n) { unsigned id = n->root_id(); return m_src_r_counts[id] == m_dst_r_counts[id]; });
return all_of(monomial(dst.r), [&](node* n) { unsigned id = n->id(); return m_src_r_counts[id] == m_dst_r_counts[id]; });
}
// src_l_counts, src_r_counts are initialized for src.l, src.r
@ -815,7 +881,7 @@ namespace euf {
SASSERT(is_correct_ref_count(monomial(dst.l), m_dst_l_counts));
SASSERT(is_correct_ref_count(monomial(dst.r), m_dst_r_counts));
for (auto n : monomial(src.l)) {
unsigned id = n->root_id();
unsigned id = n->id();
SASSERT(m_src_l_counts[id] <= m_dst_l_counts[id]);
unsigned diff = m_dst_l_counts[id] - m_src_l_counts[id];
if (diff == 0)
@ -826,7 +892,7 @@ namespace euf {
m_dst_r_counts.dec(id, diff);
}
return all_of(monomial(dst.r), [&](node* n) { unsigned id = n->root_id(); return m_src_r_counts[id] == m_dst_r_counts[id]; });
return all_of(monomial(dst.r), [&](node* n) { unsigned id = n->id(); return m_src_r_counts[id] == m_dst_r_counts[id]; });
}
void ac_plugin::rewrite1(ref_counts const& src_l, monomial_t const& src_r, ref_counts& dst_counts, ptr_vector<node>& dst) {
@ -837,14 +903,17 @@ namespace euf {
SASSERT(is_correct_ref_count(dst, dst_counts));
SASSERT(&src_r.m_nodes != &dst);
unsigned sz = dst.size(), j = 0;
bool change = false;
for (unsigned i = 0; i < sz; ++i) {
auto* n = dst[i];
unsigned id = n->root_id();
unsigned id = n->id();
unsigned dst_count = dst_counts[id];
unsigned src_count = src_l[id];
SASSERT(dst_count > 0);
if (src_count == 0)
if (src_count == 0) {
dst[j++] = n;
}
else if (src_count < dst_count) {
dst[j++] = n;
dst_counts.dec(id, 1);
@ -857,25 +926,41 @@ namespace euf {
// rewrite monomial to normal form.
bool ac_plugin::reduce(ptr_vector<node>& m, justification& j) {
bool change = false;
unsigned sz = m.size();
unsigned jj = 0;
//verbose_stream() << "start\n";
do {
init_loop:
//verbose_stream() << "loop " << jj++ << "\n";
if (m.size() == 1)
return change;
bloom b;
init_ref_counts(m, m_m_counts);
unsigned k = 0;
for (auto n : m) {
//verbose_stream() << "inner loop " << k++ << "\n";
for (auto eq : n->root->eqs) {
if (!is_processed(eq))
continue;
auto& src = m_eqs[eq];
if (!is_equation_oriented(src)) {
continue;
if (!orient_equation(src))
continue;
// deduplicate(src.l, src.r);
}
if (!can_be_subset(monomial(src.l), m, b))
continue;
if (!is_subset(m_m_counts, m_eq_counts, monomial(src.l)))
continue;
TRACE(plugin, display_equation(tout << "reduce ", src) << "\n");
TRACE(plugin, display_equation_ll(tout << "reduce ", src) << "\n");
SASSERT(is_correct_ref_count(monomial(src.l), m_eq_counts));
//display_equation_ll(std::cout << "reduce ", src) << ": ";
//display_monomial_ll(std::cout, m);
rewrite1(m_eq_counts, monomial(src.r), m_m_counts, m);
//display_monomial_ll(std::cout << " -> ", m) << "\n";
j = join(j, eq);
change = true;
goto init_loop;
@ -883,6 +968,7 @@ namespace euf {
}
}
while (false);
VERIFY(sz >= m.size());
return change;
}
@ -917,13 +1003,17 @@ namespace euf {
}
void ac_plugin::superpose(unsigned src_eq, unsigned dst_eq) {
bool ac_plugin::superpose(unsigned src_eq, unsigned dst_eq) {
if (src_eq == dst_eq)
return;
return false;
auto& src = m_eqs[src_eq];
auto& dst = m_eqs[dst_eq];
TRACE(plugin, tout << "superpose: "; display_equation(tout, src); tout << " "; display_equation(tout, dst); tout << "\n";);
unsigned max_left = std::max(monomial(src.l).size(), monomial(dst.l).size());
unsigned min_right = std::max(monomial(src.r).size(), monomial(dst.r).size());
TRACE(plugin, tout << "superpose: "; display_equation_ll(tout, src); tout << " "; display_equation_ll(tout, dst); tout << "\n";);
// AB -> C, AD -> E => BE ~ CD
// m_src_ids, m_src_counts contains information about src (call it AD -> E)
m_dst_l_counts.reset();
@ -938,7 +1028,7 @@ namespace euf {
// compute BE, initialize dst_ids, dst_counts
bool overlap = false;
for (auto n : monomial(dst.l)) {
unsigned id = n->root_id();
unsigned id = n->id();
m_dst_l_counts.inc(id, 1);
if (m_src_l_counts[id] < m_dst_l_counts[id])
m_src_r.push_back(n);
@ -947,12 +1037,12 @@ namespace euf {
if (!overlap) {
m_src_r.shrink(src_r_size);
return;
return false;
}
// compute CD
for (auto n : monomial(src.l)) {
unsigned id = n->root_id();
unsigned id = n->id();
if (m_dst_l_counts[id] > 0)
m_dst_l_counts.dec(id, 1);
else
@ -961,21 +1051,28 @@ namespace euf {
if (are_equal(m_src_r, m_dst_r)) {
m_src_r.shrink(src_r_size);
return;
return false;
}
TRACE(plugin, tout << m_pp(*this, m_src_r) << "== " << m_pp(*this, m_dst_r) << "\n";);
TRACE(plugin, tout << "superpose result: " << m_pp_ll(*this, m_src_r) << "== " << m_pp_ll(*this, m_dst_r) << "\n";);
justification j = justify_rewrite(src_eq, dst_eq);
deduplicate(m_src_r, m_dst_r);
reduce(m_dst_r, j);
reduce(m_src_r, j);
if (m_src_r.size() == 1 && m_dst_r.size() == 1)
push_merge(m_src_r[0]->n, m_dst_r[0]->n, j);
else
TRACE(plugin, tout << "superpose result: " << m_pp_ll(*this, m_src_r) << "== " << m_pp_ll(*this, m_dst_r) << "\n";);
bool added_eq = false;
unsigned max_left_new = std::max(m_src_r.size(), m_dst_r.size());
unsigned min_right_new = std::min(m_src_r.size(), m_dst_r.size());
if (max_left_new <= max_left && min_right_new <= min_right) {
init_equation(eq(to_monomial(m_src_r), to_monomial(m_dst_r), j));
added_eq = true;
}
m_src_r.reset();
m_src_r.append(monomial(src.r).m_nodes);
return added_eq;
}
bool ac_plugin::are_equal(monomial_t& a, monomial_t& b) {
@ -987,10 +1084,10 @@ namespace euf {
return false;
m_eq_counts.reset();
for (auto n : a)
m_eq_counts.inc(n->root_id(), 1);
m_eq_counts.inc(n->id(), 1);
for (auto n : b) {
unsigned id = n->root_id();
unsigned id = n->id();
if (m_eq_counts[id] == 0)
return false;
m_eq_counts.dec(id, 1);
@ -998,6 +1095,38 @@ namespace euf {
return true;
}
void ac_plugin::deduplicate(ptr_vector<node>& a, ptr_vector<node>& b) {
if (!m_is_injective)
return;
m_eq_counts.reset();
for (auto n : a)
m_eq_counts.inc(n->id(), 1);
bool has_dup = any_of(b, [&](node* n) { return m_eq_counts[n->id()] > 0; });
if (!has_dup)
return;
std::sort(a.begin(), a.end(), [&](node* x, node* y) { return x->id() < y->id(); });
std::sort(b.begin(), b.end(), [&](node* x, node* y) { return x->id() < y->id(); });
unsigned i = 0, j = 0, in = 0, jn = 0;
for (; i < a.size() && j < b.size(); ) {
if (a[i]->id() == b[j]->id()) {
++i;
++j;
}
else if (a[i]->id() < b[j]->id()) {
a[in++] = a[i++];
}
else {
b[jn++] = b[j++];
}
}
for (; i < a.size(); ++i)
a[in++] = a[i];
for (; j < b.size(); ++j)
b[jn++] = b[j];
a.shrink(in);
b.shrink(jn);
}
//
// simple version based on propagating all shared
// todo: version touching only newly processed shared, and maintaining incremental data-structures.

View file

@ -46,7 +46,7 @@ namespace euf {
unsigned_vector shared; // shared occurrences
unsigned_vector eqs; // equality occurrences
unsigned root_id() const { return root->n->get_id(); }
unsigned id() const { return root->n->get_id(); }
static node* mk(region& r, enode* n);
};
@ -117,7 +117,7 @@ namespace euf {
if (!p.is_sorted(m))
p.sort(m);
for (auto* n : m)
h = combine_hash(h, n->root_id());
h = combine_hash(h, n->id());
return h;
}
};
@ -130,7 +130,7 @@ namespace euf {
auto const& m2 = p.monomial(j);
if (m1.size() != m2.size()) return false;
for (unsigned k = 0; k < m1.size(); ++k)
if (m1[k]->root_id() != m2[k]->root_id())
if (m1[k]->id() != m2[k]->id())
return false;
return true;
}
@ -139,6 +139,7 @@ namespace euf {
theory_id m_fid = 0;
decl_kind m_op = null_decl_kind;
func_decl* m_decl = nullptr;
bool m_is_injective = false;
vector<eq> m_eqs;
ptr_vector<node> m_nodes;
bool_vector m_shared_nodes;
@ -208,7 +209,8 @@ namespace euf {
void forward_simplify(unsigned eq_id, unsigned using_eq);
bool backward_simplify(unsigned eq_id, unsigned using_eq);
void superpose(unsigned src_eq, unsigned dst_eq);
bool superpose(unsigned src_eq, unsigned dst_eq);
void deduplicate(ptr_vector<node>& a, ptr_vector<node>& b);
ptr_vector<node> m_src_r, m_src_l, m_dst_r, m_dst_l;
@ -236,6 +238,7 @@ namespace euf {
void compress_eq_occurs(unsigned eq_id);
// check that src is a subset of dst, where dst_counts are precomputed
bool is_subset(ref_counts const& dst_counts, ref_counts& src_counts, monomial_t const& src);
bool is_equation_oriented(eq const& e) const;
// check that dst is a superset of dst, where src_counts are precomputed
bool is_superset(ref_counts const& src_counts, ref_counts& dst_counts, monomial_t const& dst);
@ -262,6 +265,9 @@ namespace euf {
std::ostream& display_monomial(std::ostream& out, monomial_t const& m) const { return display_monomial(out, m.m_nodes); }
std::ostream& display_monomial(std::ostream& out, ptr_vector<node> const& m) const;
std::ostream& display_equation(std::ostream& out, eq const& e) const;
std::ostream& display_monomial_ll(std::ostream& out, monomial_t const& m) const { return display_monomial_ll(out, m.m_nodes); }
std::ostream& display_monomial_ll(std::ostream& out, ptr_vector<node> const& m) const;
std::ostream& display_equation_ll(std::ostream& out, eq const& e) const;
std::ostream& display_status(std::ostream& out, eq_status s) const;
@ -271,6 +277,8 @@ namespace euf {
ac_plugin(egraph& g, func_decl* f);
void set_injective() { m_is_injective = true; }
theory_id get_id() const override { return m_fid; }
void register_node(enode* n) override;
@ -288,20 +296,36 @@ namespace euf {
void set_undo(std::function<void(void)> u) { m_undo_notify = u; }
struct eq_pp {
ac_plugin& p; eq const& e;
eq_pp(ac_plugin& p, eq const& e) : p(p), e(e) {};
eq_pp(ac_plugin& p, unsigned eq_id): p(p), e(p.m_eqs[eq_id]) {}
ac_plugin const& p; eq const& e;
eq_pp(ac_plugin const& p, eq const& e) : p(p), e(e) {};
eq_pp(ac_plugin const& p, unsigned eq_id): p(p), e(p.m_eqs[eq_id]) {}
std::ostream& display(std::ostream& out) const { return p.display_equation(out, e); }
};
struct eq_pp_ll {
ac_plugin const& p; eq const& e;
eq_pp_ll(ac_plugin const& p, eq const& e) : p(p), e(e) {};
eq_pp_ll(ac_plugin const& p, unsigned eq_id) : p(p), e(p.m_eqs[eq_id]) {}
std::ostream& display(std::ostream& out) const { return p.display_equation_ll(out, e); }
};
struct m_pp {
ac_plugin& p; ptr_vector<node> const& m;
m_pp(ac_plugin& p, monomial_t const& m) : p(p), m(m.m_nodes) {}
m_pp(ac_plugin& p, ptr_vector<node> const& m) : p(p), m(m) {}
ac_plugin const& p; ptr_vector<node> const& m;
m_pp(ac_plugin const& p, monomial_t const& m) : p(p), m(m.m_nodes) {}
m_pp(ac_plugin const& p, ptr_vector<node> const& m) : p(p), m(m) {}
std::ostream& display(std::ostream& out) const { return p.display_monomial(out, m); }
};
struct m_pp_ll {
ac_plugin const& p; ptr_vector<node> const& m;
m_pp_ll(ac_plugin const& p, monomial_t const& m) : p(p), m(m.m_nodes) {}
m_pp_ll(ac_plugin const& p, ptr_vector<node> const& m) : p(p), m(m) {}
std::ostream& display(std::ostream& out) const { return p.display_monomial_ll(out, m); }
};
};
inline std::ostream& operator<<(std::ostream& out, ac_plugin::eq_pp const& d) { return d.display(out); }
inline std::ostream& operator<<(std::ostream& out, ac_plugin::eq_pp_ll const& d) { return d.display(out); }
inline std::ostream& operator<<(std::ostream& out, ac_plugin::m_pp const& d) { return d.display(out); }
inline std::ostream& operator<<(std::ostream& out, ac_plugin::m_pp_ll const& d) { return d.display(out); }
}

View file

@ -30,6 +30,7 @@ namespace euf {
m_add.set_undo(uadd);
std::function<void(void)> umul = [&]() { m_undo.push_back(undo_t::undo_mul); };
m_mul.set_undo(umul);
m_add.set_injective();
}
void arith_plugin::register_node(enode* n) {

View file

@ -26,11 +26,15 @@ namespace euf {
}
void plugin::push_merge(enode* a, enode* b, justification j) {
if (a->get_root() == b->get_root())
return; // already merged
TRACE(euf, tout << "push-merge " << g.bpp(a) << " == " << g.bpp(b) << " " << j << "\n");
g.push_merge(a, b, j);
}
void plugin::push_merge(enode* a, enode* b) {
if (a->get_root() == b->get_root())
return; // already merged
TRACE(plugin, tout << g.bpp(a) << " == " << g.bpp(b) << "\n");
g.push_merge(a, b, justification::axiom(get_id()));
}

View file

@ -35,6 +35,8 @@ void expr_abstract(ast_manager& m, unsigned base, unsigned num_bound, expr* cons
inline expr_ref expr_abstract(ast_manager& m, unsigned base, unsigned num_bound, expr* const* bound, expr* n) { expr_ref r(m); expr_abstract(m, base, num_bound, bound, n, r); return r; }
inline expr_ref expr_abstract(expr_ref_vector const& bound, expr* n) { return expr_abstract(bound.m(), 0, bound.size(), bound.data(), n); }
inline expr_ref expr_abstract(app_ref_vector const& bound, expr* n) { return expr_abstract(bound.m(), 0, bound.size(), (expr*const*)bound.data(), n); }
inline expr_ref expr_abstract(ast_manager& m, ptr_vector<expr> const& bound, expr* n) { return expr_abstract(m, 0, bound.size(), bound.data(), n); }
expr_ref mk_forall(ast_manager& m, unsigned num_bound, app* const* bound, expr* n);
expr_ref mk_exists(ast_manager& m, unsigned num_bound, app* const* bound, expr* n);
inline expr_ref mk_forall(ast_manager& m, app* b, expr* n) { return mk_forall(m, 1, &b, n); }

View file

@ -50,6 +50,7 @@ Mam optimization?
#include "ast/ast_pp.h"
#include "ast/ast_util.h"
#include "ast/expr_abstract.h"
#include "ast/euf/euf_egraph.h"
#include "ast/euf/euf_arith_plugin.h"
#include "ast/euf/euf_bv_plugin.h"
@ -276,6 +277,9 @@ namespace euf {
expr_ref y1(y, m);
m_rewriter(x1);
m_rewriter(y1);
add_quantifiers(x1);
add_quantifiers(y1);
enode* a = mk_enode(x1);
enode* b = mk_enode(y1);
if (a->get_root() == b->get_root())
@ -283,25 +287,40 @@ namespace euf {
m_egraph.merge(a, b, to_ptr(push_pr_dep(pr, d)));
add_children(a);
add_children(b);
auto a1 = mk_enode(x);
if (a1->get_root() != a->get_root()) {
m_egraph.merge(a, a1, nullptr);
add_children(a1);
}
auto b1 = mk_enode(y);
if (b1->get_root() != b->get_root()) {
m_egraph.merge(b, b1, nullptr);
add_children(b1);
}
m_should_propagate = true;
if (m_side_condition_solver)
m_side_condition_solver->add_constraint(f, pr, d);
IF_VERBOSE(1, verbose_stream() << "eq: " << mk_pp(x1, m) << " == " << mk_pp(y1, m) << "\n");
}
else if (m.is_not(f, f)) {
enode* n = mk_enode(f);
if (m.is_false(n->get_root()->get_expr()))
return;
add_quantifiers(f);
auto j = to_ptr(push_pr_dep(pr, d));
m_egraph.new_diseq(n, j);
add_children(n);
m_should_propagate = true;
if (m_side_condition_solver)
m_side_condition_solver->add_constraint(f, pr, d);
IF_VERBOSE(1, verbose_stream() << "not: " << mk_pp(f, m) << "\n");
}
else {
enode* n = mk_enode(f);
if (m.is_true(n->get_root()->get_expr()))
return;
IF_VERBOSE(1, verbose_stream() << "fml: " << mk_pp(f, m) << "\n");
m_egraph.merge(n, m_tt, to_ptr(push_pr_dep(pr, d)));
add_children(n);
if (is_forall(f)) {
@ -331,11 +350,64 @@ namespace euf {
m_q2dep.insert(q, { pr, d });
get_trail().push(insert_obj_map(m_q2dep, q));
}
add_rule(f, pr, d);
if (!is_forall(f) && !m.is_implies(f) && m_side_condition_solver)
if (!is_forall(f) && !m.is_implies(f)) {
add_quantifiers(f);
if (m_side_condition_solver)
m_side_condition_solver->add_constraint(f, pr, d);
}
}
}
void completion::add_quantifiers(expr* f) {
if (!has_quantifiers(f))
return;
ptr_vector<expr> bound;
add_quantifiers(bound, f);
}
void completion::add_quantifiers(ptr_vector<expr>& bound, expr* f) {
if (!has_quantifiers(f))
return;
ptr_vector<expr> todo;
todo.push_back(f);
expr_fast_mark1 visited;
for (unsigned j = 0; j < todo.size(); ++j) {
expr* t = todo[j];
if (visited.is_marked(t))
continue;
visited.mark(t);
if (!has_quantifiers(t))
continue;
if (is_app(t)) {
for (auto arg : *to_app(t))
todo.push_back(arg);
}
else if (is_quantifier(t)) {
auto q = to_quantifier(t);
auto nd = q->get_num_decls();
verbose_stream() << "bind " << mk_pp(q, m) << "\n";
for (unsigned i = 0; i < nd; ++i) {
auto name = std::string("bound!") + std::to_string(bound.size());
auto b = m.mk_const(name, q->get_decl_sort(i));
// TODO: persist bound variables withn scope to avoid reference count crashes
bound.push_back(b);
}
expr_ref inst = var_subst(m)(q->get_expr(), bound);
if (!m_egraph.find(inst)) {
m_closures.insert(q, { bound, inst });
get_trail().push(insert_map(m_closures, q));
mk_enode(inst);
// TODO: handle nested quantifiers after m_closures is updated to
// index on sort declaration prefix together with quantifier
// add_quantifiers(bound, inst);
}
bound.shrink(bound.size() - nd);
}
}
}
lbool completion::eval_cond(expr* f, proof_ref& pr, expr_dependency*& d) {
auto n = mk_enode(f);
@ -760,6 +832,9 @@ namespace euf {
}
expr_ref completion::canonize(expr* f, proof_ref& pr, expr_dependency_ref& d) {
if (is_quantifier(f))
return expr_ref(canonize(to_quantifier(f), pr, d), m);
if (!is_app(f))
return expr_ref(f, m); // todo could normalize ground expressions under quantifiers
@ -787,8 +862,29 @@ namespace euf {
return r;
}
expr_ref completion::canonize(quantifier* q, proof_ref& pr, expr_dependency_ref& d) {
std::pair<ptr_vector<expr>, expr*> clos;
if (!m_closures.find(q, clos))
return expr_ref(q, m);
expr* body = clos.second;
expr_ref new_body = canonize(body, pr, d);
expr_ref result = expr_abstract(m, clos.first, new_body);
if (m.proofs_enabled()) {
// add proof rule
//
// body = new_body
// ---------------------------
// Q x . body = Q x . new_body
NOT_IMPLEMENTED_YET();
}
return result;
}
expr* completion::get_canonical(expr* f, proof_ref& pr, expr_dependency_ref& d) {
enode* n = m_egraph.find(f);
if (!n) verbose_stream() << "not found " << f->get_id() << " " << mk_pp(f, m) << "\n";
enode* r = n->get_root();
d = m.mk_join(d, explain_eq(n, r));
d = m.mk_join(d, m_deps.get(r->get_id(), nullptr));

View file

@ -171,6 +171,19 @@ namespace euf {
proof* get_canonical_proof(enode* n);
void set_canonical(enode* n, expr* e, proof* pr);
void add_constraint(expr*f, proof* pr, expr_dependency* d);
// Enable equality propagation inside of quantifiers
// add quantifier bodies as closure terms to the E-graph.
// use fresh variables for bound variables, but such that the fresh variables are
// the same when the quantifier prefix is the same.
// Thus, we are going to miss equalities of quantifier bodies
// if the prefixes are different but the bodies are the same.
// Closure terms are re-abstracted by the canonizer.
void add_quantifiers(ptr_vector<expr>& bound, expr* t);
void add_quantifiers(expr* t);
expr_ref canonize(quantifier* q, proof_ref& pr, expr_dependency_ref& d);
obj_map<quantifier, std::pair<ptr_vector<expr>, expr*>> m_closures;
expr_dependency* explain_eq(enode* a, enode* b);
proof_ref prove_eq(enode* a, enode* b);
proof_ref prove_conflict();