3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-22 16:45:31 +00:00

fixes to AC plugin

This commit is contained in:
Nikolaj Bjorner 2023-11-28 12:50:43 -08:00
parent 14483dcd6e
commit a805e1f27d
26 changed files with 887 additions and 293 deletions

View file

@ -6,7 +6,9 @@ z3_add_component(euf
euf_egraph.cpp
euf_enode.cpp
euf_etable.cpp
euf_justification.cpp
euf_plugin.cpp
euf_specrel_plugin.cpp
COMPONENT_DEPENDENCIES
ast
util

View file

@ -67,6 +67,7 @@ TODOs:
#include "ast/euf/euf_ac_plugin.h"
#include "ast/euf/euf_egraph.h"
#include "ast/ast_pp.h"
namespace euf {
@ -74,7 +75,18 @@ namespace euf {
plugin(g), m_fid(fid), m_op(op),
m_dep_manager(get_region()),
m_hash(*this), m_eq(*this), m_monomial_table(m_hash, m_eq)
{}
{
g.set_th_propagates_diseqs(m_fid);
}
ac_plugin::ac_plugin(egraph& g, func_decl* f) :
plugin(g), m_decl(f), m_fid(f->get_family_id()),
m_dep_manager(get_region()),
m_hash(*this), m_eq(*this), m_monomial_table(m_hash, m_eq)
{
if (m_fid != null_family_id)
g.set_th_propagates_diseqs(m_fid);
}
void ac_plugin::register_node(enode* n) {
if (is_op(n))
@ -85,16 +97,19 @@ namespace euf {
}
void ac_plugin::register_shared(enode* n) {
if (m_shared_nodes.get(n->get_id(), false))
return;
auto m = to_monomial(n);
auto const& ns = monomial(m);
for (auto arg : ns) {
arg->shared.push_back(m);
m_node_trail.push_back(arg);
push_undo(is_add_shared);
push_undo(is_add_shared_index);
}
m_shared_nodes.setx(n->get_id(), true, false);
sort(monomial(m));
m_shared_todo.insert(m_shared.size());
m_shared.push_back({ n, m, justification::axiom() });
m_shared_todo.insert(m);
push_undo(is_register_shared);
}
@ -103,11 +118,6 @@ namespace euf {
m_undo.pop_back();
switch (k) {
case is_add_eq: {
auto const& eq = m_eqs.back();
for (auto* n : monomial(eq.l))
n->eqs.pop_back();
for (auto* n : monomial(eq.r))
n->eqs.pop_back();
m_eqs.pop_back();
break;
}
@ -138,13 +148,21 @@ namespace euf {
m_update_eq_trail.pop_back();
break;
}
case is_add_shared: {
case is_add_shared_index: {
auto n = m_node_trail.back();
m_node_trail.pop_back();
n->shared.pop_back();
break;
}
case is_add_eq_index: {
auto n = m_node_trail.back();
m_node_trail.pop_back();
n->eqs.pop_back();
break;
}
case is_register_shared: {
auto s = m_shared.back();
m_shared_nodes[s.n->get_id()] = false;
m_shared.pop_back();
break;
}
@ -159,9 +177,13 @@ namespace euf {
}
}
std::ostream& ac_plugin::display_monomial(std::ostream& out, monomial_t const& m) const {
for (auto n : m)
out << g.bpp(n->n) << " ";
std::ostream& ac_plugin::display_monomial(std::ostream& out, ptr_vector<node> const& m) const {
for (auto n : m) {
if (n->n->num_args() == 0)
out << mk_pp(n->n->get_expr(), g.get_manager()) << " ";
else
out << g.bpp(n->n) << " ";
}
return out;
}
@ -200,6 +222,8 @@ namespace euf {
for (auto n : m_nodes) {
if (!n)
continue;
if (n->eqs.empty() && n->shared.empty())
continue;
out << g.bpp(n->n) << " r: " << n->root_id() << " ";
if (!n->eqs.empty()) {
out << "eqs ";
@ -216,25 +240,57 @@ namespace euf {
return out;
}
void ac_plugin::merge_eh(enode* l, enode* r, justification j) {
void ac_plugin::merge_eh(enode* l, enode* r) {
if (l == r)
return;
auto j = justification::equality(l, r);
if (!is_op(l) && !is_op(r))
merge(mk_node(l), mk_node(r), j);
else
init_equation(eq(to_monomial(l), to_monomial(r), j));
}
void ac_plugin::diseq_eh(enode* eq) {
SASSERT(g.get_manager().is_eq(eq->get_expr()));
enode* a = eq->get_arg(0), * b = eq->get_arg(1);
a = a->get_closest_th_node(m_fid);
b = b->get_closest_th_node(m_fid);
SASSERT(a && b);
register_shared(a);
register_shared(b);
}
void ac_plugin::init_equation(eq const& e) {
m_eqs.push_back(e);
auto& eq = m_eqs.back();
if (orient_equation(eq)) {
push_undo(is_add_eq);
unsigned eq_id = m_eqs.size() - 1;
for (auto n : monomial(eq.l))
n->eqs.push_back(eq_id);
for (auto n : monomial(eq.l)) {
if (!n->root->n->is_marked1()) {
n->root->eqs.push_back(eq_id);
n->root->n->mark1();
push_undo(is_add_eq_index);
m_node_trail.push_back(n->root);
}
}
for (auto n : monomial(eq.r)) {
if (!n->root->n->is_marked1()) {
n->root->eqs.push_back(eq_id);
n->root->n->mark1();
push_undo(is_add_eq_index);
m_node_trail.push_back(n->root);
}
}
for (auto n : monomial(eq.l))
n->root->n->unmark1();
for (auto n : monomial(eq.r))
n->eqs.push_back(eq_id);
n->root->n->unmark1();
m_to_simplify_todo.insert(eq_id);
}
else
@ -298,6 +354,19 @@ namespace euf {
return (f1 | f2) == f2;
}
bool ac_plugin::can_be_subset(monomial_t& subset, ptr_vector<node> const& m, bloom& bloom) {
if (subset.size() > m.size())
return false;
if (bloom.m_tick != m_tick) {
bloom.m_filter = 0;
for (auto n : m)
bloom.m_filter |= (1ull << (n->root_id() % 64ull));
bloom.m_tick = m_tick;
}
auto f2 = bloom.m_filter;
return (filter(subset) | f2) == f2;
}
void ac_plugin::merge(node* root, node* other, justification j) {
for (auto n : equiv(other))
n->root = root;
@ -326,15 +395,10 @@ namespace euf {
ns.push_back(n);
for (unsigned i = 0; i < ns.size(); ++i) {
n = ns[i];
if (is_op(n)) {
ns.append(n->num_args(), n->args());
ns[i] = ns.back();
ns.pop_back();
--i;
}
else {
m.push_back(mk_node(n));
}
if (is_op(n))
ns.append(n->num_args(), n->args());
else
m.push_back(mk_node(n));
}
return to_monomial(n, m);
}
@ -367,7 +431,6 @@ namespace euf {
}
void ac_plugin::propagate() {
TRACE("plugin", display(tout));
while (true) {
loop_start:
unsigned eq_id = pick_next_eq();
@ -377,10 +440,14 @@ namespace euf {
TRACE("plugin", tout << "propagate " << eq_id << ": " << eq_pp(*this, m_eqs[eq_id]) << "\n");
// simplify eq using processed
for (auto other_eq : backward_iterator(eq_id))
TRACE("plugin", tout << "backward iterator " << eq_id << " vs " << other_eq << " " << is_processed(other_eq) << "\n");
for (auto other_eq : backward_iterator(eq_id))
if (is_processed(other_eq) && backward_simplify(eq_id, other_eq))
goto loop_start;
set_status(eq_id, eq_status::processed);
// simplify processed using eq
for (auto other_eq : forward_iterator(eq_id))
if (is_processed(other_eq))
@ -395,10 +462,10 @@ namespace euf {
for (auto other_eq : forward_iterator(eq_id))
if (is_to_simplify(other_eq))
forward_simplify(eq_id, other_eq);
set_status(eq_id, eq_status::processed);
}
propagate_shared();
CTRACE("plugin", !m_shared.empty() || !m_eqs.empty(), display(tout));
}
unsigned ac_plugin::pick_next_eq() {
@ -456,6 +523,8 @@ namespace euf {
auto const& eq = m_eqs[eq_id];
init_ref_counts(monomial(eq.r), m_dst_r_counts);
init_ref_counts(monomial(eq.l), m_dst_l_counts);
m_dst_r.reset();
m_dst_r.append(monomial(eq.r).m_nodes);
init_subset_iterator(eq_id, monomial(eq.r));
return m_eq_occurs;
}
@ -479,12 +548,20 @@ namespace euf {
node* max_n = nullptr;
bool has_two = false;
for (auto n : m)
if (n->root->eqs.size() > max_use)
if (n->root->eqs.size() >= max_use)
has_two |= max_n && (max_n != n->root), max_n = n->root, max_use = n->root->eqs.size();
m_eq_occurs.reset();
for (auto n : m)
if (n->root != max_n && has_two)
if (has_two) {
for (auto n : m)
if (n->root != max_n)
m_eq_occurs.append(n->root->eqs);
}
else {
for (auto n : m) {
m_eq_occurs.append(n->root->eqs);
break;
}
}
compress_eq_occurs(eq_id);
}
@ -525,10 +602,26 @@ namespace euf {
return min_n->eqs;
}
void ac_plugin::init_ref_counts(monomial_t const& monomial, ref_counts& counts) {
counts.reset();
for (auto n : monomial)
counts.inc(n->root_id(), 1);
void ac_plugin::init_ref_counts(monomial_t const& monomial, ref_counts& counts) const {
init_ref_counts(monomial.m_nodes, counts);
}
void ac_plugin::init_ref_counts(ptr_vector<node> const& monomial, ref_counts& counts) const {
counts.reset();
for (auto n : monomial)
counts.inc(n->root_id(), 1);
}
bool ac_plugin::is_correct_ref_count(monomial_t const& m, ref_counts const& counts) const {
return is_correct_ref_count(m.m_nodes, counts);
}
bool ac_plugin::is_correct_ref_count(ptr_vector<node> const& m, ref_counts const& counts) const {
ref_counts check;
init_ref_counts(m, check);
return
all_of(counts, [&](unsigned i) { return check[i] == counts[i]; }) &&
all_of(check, [&](unsigned i) { return check[i] == counts[i]; });
}
void ac_plugin::forward_simplify(unsigned src_eq, unsigned dst_eq) {
@ -540,12 +633,14 @@ namespace euf {
// dst = A -> BC
// src = B -> D
// post(dst) := A -> CD
auto& src = m_eqs[src_eq];
auto& src = m_eqs[src_eq]; // src_r_counts, src_l_counts are initialized
auto& dst = m_eqs[dst_eq];
TRACE("plugin", tout << "forward simplify " << eq_pp(*this, src) << " " << eq_pp(*this, dst) << "\n");
if (forward_subsumes(src_eq, dst_eq)) {
TRACE("plugin", tout << "forward subsumed\n");
set_status(dst_eq, eq_status::is_dead);
return;
}
@ -553,34 +648,49 @@ namespace euf {
if (!can_be_subset(monomial(src.l), monomial(dst.r)))
return;
m_dst_r_counts.reset();
unsigned src_l_size = monomial(src.l).size();
unsigned src_r_size = m_src_r.size();
SASSERT(is_correct_ref_count(monomial(src.l), m_src_l_counts));
// subtract src.l from dst.r if src.l is a subset of dst.r
// new_rhs := old_rhs - src_lhs + src_rhs
// dst_rhs := dst_rhs - src_lhs + src_rhs
// := src_rhs + (dst_rhs - src_lhs)
// := src_rhs + elements from dst_rhs that are in excess of src_lhs
unsigned num_overlap = 0;
for (auto n : monomial(dst.r)) {
unsigned id = n->root_id();
unsigned count = m_src_l_counts[id];
if (count == 0)
m_src_r.push_back(n);
else if (m_dst_r_counts[id] >= count)
unsigned dst_count = m_dst_r_counts[id];
unsigned src_count = m_src_l_counts[id];
if (dst_count > src_count) {
m_src_r.push_back(n);
m_dst_r_counts.dec(id, 1);
}
else if (dst_count < src_count) {
m_src_r.shrink(src_r_size);
return;
}
else
m_dst_r_counts.inc(id, 1), ++num_overlap;
++num_overlap;
}
// The dst.r has to be a superset of src.l, otherwise simplification does not apply
if (num_overlap == src_l_size) {
auto new_r = to_monomial(nullptr, m_src_r);
m_update_eq_trail.push_back({ dst_eq, m_eqs[dst_eq] });
m_eqs[dst_eq].r = new_r;
m_eqs[dst_eq].j = justify_rewrite(src_eq, dst_eq);
push_undo(is_update_eq);
TRACE("plugin", tout << "rewritten to " << m_pp(*this, monomial(new_r)) << "\n");
if (num_overlap != src_l_size) {
m_src_r.shrink(src_r_size);
return;
}
m_src_r.shrink(src_r_size);
auto j = justify_rewrite(src_eq, dst_eq);
reduce(m_src_r, j);
auto new_r = to_monomial(m_src_r);
index_new_r(dst_eq, monomial(m_eqs[dst_eq].r), monomial(new_r));
m_update_eq_trail.push_back({ dst_eq, m_eqs[dst_eq] });
m_eqs[dst_eq].r = new_r;
m_eqs[dst_eq].j = j;
push_undo(is_update_eq);
m_src_r.reset();
m_src_r.append(monomial(src.r).m_nodes);
TRACE("plugin", tout << "rewritten to " << m_pp(*this, monomial(new_r)) << "\n");
}
bool ac_plugin::backward_simplify(unsigned dst_eq, unsigned src_eq) {
@ -588,25 +698,38 @@ namespace euf {
return false;
auto& src = m_eqs[src_eq];
auto& dst = m_eqs[dst_eq];
auto& dst = m_eqs[dst_eq]; // pre-computed dst_r_counts, dst_l_counts
//
// dst_ids, dst_count contain rhs of dst_eq
//
TRACE("plugin", tout << "backward simplify " << eq_pp(*this, src) << " " << eq_pp(*this, dst) << "\n");
// check that src.l is a subset of dst.r
if (!can_be_subset(monomial(src.l), monomial(dst.r)))
return false;
if (!is_subset(m_dst_l_counts, m_src_l_counts, monomial(src.l)))
return false;
if (backward_subsumes(src_eq, dst_eq)) {
TRACE("plugin", tout << "backward simplify " << eq_pp(*this, src) << " " << eq_pp(*this, dst) << " can-be-subset: " << can_be_subset(monomial(src.l), monomial(dst.r)) << "\n");
if (backward_subsumes(src_eq, dst_eq)) {
TRACE("plugin", tout << "backward subsumed\n");
set_status(dst_eq, eq_status::is_dead);
return true;
}
// dst_rhs := dst_rhs - src_lhs + src_rhs
auto new_r = rewrite(monomial(src.r), monomial(dst.r));
// check that src.l is a subset of dst.r
if (!can_be_subset(monomial(src.l), monomial(dst.r)))
return false;
if (!is_subset(m_dst_r_counts, m_src_l_counts, monomial(src.l))) {
TRACE("plugin", tout << "not subset\n");
return false;
}
SASSERT(is_correct_ref_count(monomial(dst.r), m_dst_r_counts));
ptr_vector<node> m(m_dst_r);
init_ref_counts(monomial(src.l), m_src_l_counts);
rewrite1(m_src_l_counts, monomial(src.r), m_dst_r_counts, m);
auto j = justify_rewrite(src_eq, dst_eq);
reduce(m, j);
auto new_r = to_monomial(m);
index_new_r(dst_eq, monomial(m_eqs[dst_eq].r), monomial(new_r));
m_update_eq_trail.push_back({ dst_eq, m_eqs[dst_eq] });
m_eqs[dst_eq].r = new_r;
m_eqs[dst_eq].j = justify_rewrite(src_eq, dst_eq);
m_eqs[dst_eq].j = j;
TRACE("plugin", tout << "rewritten to " << m_pp(*this, monomial(new_r)) << "\n");
push_undo(is_update_eq);
return true;
@ -618,6 +741,8 @@ namespace euf {
bool ac_plugin::backward_subsumes(unsigned src_eq, unsigned dst_eq) {
auto& src = m_eqs[src_eq];
auto& dst = m_eqs[dst_eq];
SASSERT(is_correct_ref_count(monomial(dst.l), m_dst_l_counts));
SASSERT(is_correct_ref_count(monomial(dst.r), m_dst_r_counts));
if (!can_be_subset(monomial(src.l), monomial(dst.l)))
return false;
if (!can_be_subset(monomial(src.r), monomial(dst.r)))
@ -625,13 +750,14 @@ namespace euf {
unsigned size_diff = monomial(dst.l).size() - monomial(src.l).size();
if (size_diff != monomial(dst.r).size() - monomial(src.r).size())
return false;
if (!is_superset(m_dst_l_counts, m_src_l_counts, monomial(src.l)))
if (!is_subset(m_dst_l_counts, m_src_l_counts, monomial(src.l)))
return false;
if (!is_superset(m_dst_r_counts, m_src_r_counts, monomial(src.r)))
return false;
// add difference betwen src and dst1 to dst2
// (also add it to dst1 to make sure same difference isn't counted twice).
for (auto n : monomial(src.l)) {
if (!is_subset(m_dst_r_counts, m_src_r_counts, monomial(src.r)))
return false;
SASSERT(is_correct_ref_count(monomial(src.l), m_src_l_counts));
SASSERT(is_correct_ref_count(monomial(src.r), m_src_r_counts));
// add difference betwen dst.l and src.l to both src.l, src.r
for (auto n : monomial(dst.l)) {
unsigned id = n->root_id();
SASSERT(m_dst_l_counts[id] >= m_src_l_counts[id]);
unsigned diff = m_dst_l_counts[id] - m_src_l_counts[id];
@ -645,10 +771,12 @@ namespace euf {
return all_of(monomial(dst.r), [&](node* n) { unsigned id = n->root_id(); return m_src_r_counts[id] == m_dst_r_counts[id]; });
}
// src_counts, src2_counts are initialized for src_eq
// src_l_counts, src_r_counts are initialized for src.l, src.r
bool ac_plugin::forward_subsumes(unsigned src_eq, unsigned dst_eq) {
auto& src = m_eqs[src_eq];
auto& dst = m_eqs[dst_eq];
SASSERT(is_correct_ref_count(monomial(src.l), m_src_l_counts));
SASSERT(is_correct_ref_count(monomial(src.r), m_src_r_counts));
if (!can_be_subset(monomial(src.l), monomial(dst.l)))
return false;
if (!can_be_subset(monomial(src.r), monomial(dst.r)))
@ -658,52 +786,87 @@ namespace euf {
return false;
if (!is_superset(m_src_l_counts, m_dst_l_counts, monomial(dst.l)))
return false;
if (!is_subset(m_src_r_counts, m_dst_r_counts, monomial(dst.r)))
if (!is_superset(m_src_r_counts, m_dst_r_counts, monomial(dst.r)))
return false;
SASSERT(is_correct_ref_count(monomial(dst.l), m_dst_l_counts));
SASSERT(is_correct_ref_count(monomial(dst.r), m_dst_r_counts));
for (auto n : monomial(src.l)) {
unsigned id = n->root_id();
SASSERT(m_src_l_counts[id] >= m_dst_l_counts[id]);
unsigned diff = m_src_l_counts[id] - m_dst_l_counts[id];
if (diff > 0) {
m_dst_l_counts.inc(id, diff);
m_dst_r_counts.inc(id, diff);
}
SASSERT(m_src_l_counts[id] <= m_dst_l_counts[id]);
unsigned diff = m_dst_l_counts[id] - m_src_l_counts[id];
if (diff == 0)
continue;
m_dst_l_counts.dec(id, diff);
if (m_dst_r_counts[id] < diff)
return false;
m_dst_r_counts.dec(id, diff);
}
return all_of(monomial(dst.r), [&](node* n) { unsigned id = n->root_id(); return m_src_r_counts[id] == m_dst_r_counts[id]; });
}
unsigned ac_plugin::rewrite(monomial_t const& src_r, monomial_t const& dst_r) {
// pre-condition: is-subset is invoked so that m_src_count is initialized.
// pre-condition: m_dst_count is also initialized (once).
m_src_r.reset();
m_src_r.append(src_r.m_nodes);
// add to m_src_r elements of dst.r that are not in src.l
for (auto n : dst_r) {
void ac_plugin::rewrite1(ref_counts const& src_l, monomial_t const& src_r, ref_counts& dst_counts, ptr_vector<node>& dst) {
// pre-condition: is-subset is invoked so that src_l is initialized.
// pre-condition: dst_count is also initialized.
// remove from dst elements that are in src_l
// add elements from src_r
SASSERT(is_correct_ref_count(dst, dst_counts));
SASSERT(&src_r.m_nodes != &dst);
unsigned sz = dst.size(), j = 0;
for (unsigned i = 0; i < sz; ++i) {
auto* n = dst[i];
unsigned id = n->root_id();
unsigned count = m_src_l_counts[id];
if (count == 0)
m_src_r.push_back(n);
else
m_src_l_counts.inc(id, -1);
unsigned dst_count = dst_counts[id];
unsigned src_count = src_l[id];
SASSERT(dst_count > 0);
if (src_count == 0)
dst[j++] = n;
else if (src_count < dst_count) {
dst[j++] = n;
dst_counts.dec(id, 1);
}
}
return to_monomial(nullptr, m_src_r);
dst.shrink(j);
dst.append(src_r.m_nodes);
}
// rewrite monomial to normal form.
bool ac_plugin::reduce(ptr_vector<node>& m, justification& j) {
bool change = false;
do {
init_loop:
if (m.size() == 1)
return change;
bloom b;
init_ref_counts(m, m_m_counts);
for (auto n : m) {
for (auto eq : n->root->eqs) {
if (!is_processed(eq))
continue;
auto& src = m_eqs[eq];
if (!can_be_subset(monomial(src.l), m, b))
continue;
if (!is_subset(m_m_counts, m_eq_counts, monomial(src.l)))
continue;
TRACE("plugin", display_equation(tout << "reduce ", src) << "\n");
SASSERT(is_correct_ref_count(monomial(src.l), m_eq_counts));
rewrite1(m_eq_counts, monomial(src.r), m_m_counts, m);
j = join(j, eq);
change = true;
goto init_loop;
}
}
}
while (false);
return change;
}
// check that src is a subset of dst, where dst_counts are precomputed
bool ac_plugin::is_subset(ref_counts const& dst_counts, ref_counts& src_counts, monomial_t const& src) {
SASSERT(&dst_counts != &src_counts);
src_counts.reset();
for (auto n : src) {
unsigned id = n->root_id();
unsigned dst_count = dst_counts[id];
if (dst_count == 0)
return false;
else if (src_counts[id] >= dst_count)
return false;
else
src_counts.inc(id, 1);
}
return true;
init_ref_counts(src, src_counts);
return all_of(src_counts, [&](unsigned idx) { return src_counts[idx] <= dst_counts[idx]; });
}
// check that dst is a superset of src, where src_counts are precomputed
@ -713,6 +876,23 @@ namespace euf {
return all_of(src_counts, [&](unsigned idx) { return src_counts[idx] <= dst_counts[idx]; });
}
void ac_plugin::index_new_r(unsigned eq, monomial_t const& old_r, monomial_t const& new_r) {
for (auto n : old_r)
n->root->n->mark1();
for (auto n : new_r)
if (!n->root->n->is_marked1()) {
n->root->eqs.push_back(eq);
m_node_trail.push_back(n->root);
n->root->n->mark1();
push_undo(is_add_eq_index);
}
for (auto n : old_r)
n->root->n->unmark1();
for (auto n : new_r)
n->root->n->unmark1();
}
void ac_plugin::superpose(unsigned src_eq, unsigned dst_eq) {
if (src_eq == dst_eq)
return;
@ -733,12 +913,20 @@ namespace euf {
// src_r contains E
// compute BE, initialize dst_ids, dst_counts
bool overlap = false;
for (auto n : monomial(dst.l)) {
unsigned id = n->root_id();
m_dst_l_counts.inc(id, 1);
if (m_src_l_counts[id] < m_dst_l_counts[id])
m_src_r.push_back(n);
m_dst_l_counts.inc(id, 1);
overlap |= m_src_l_counts[id] > 0;
}
if (!overlap) {
m_src_r.shrink(src_r_size);
return;
}
// compute CD
for (auto n : monomial(src.l)) {
unsigned id = n->root_id();
@ -753,16 +941,18 @@ namespace euf {
return;
}
TRACE("plugin", for (auto n : m_src_r) tout << g.bpp(n->n) << " "; tout << "== "; for (auto n : m_dst_r) tout << g.bpp(n->n) << " "; tout << "\n";);
TRACE("plugin", tout << m_pp(*this, m_src_r) << "== " << m_pp(*this, m_dst_r) << "\n";);
justification j = justify_rewrite(src_eq, dst_eq);
reduce(m_dst_r, j);
reduce(m_src_r, j);
if (m_src_r.size() == 1 && m_dst_r.size() == 1)
push_merge(m_src_r[0]->n, m_dst_r[0]->n, j);
else
init_equation(eq(to_monomial(nullptr, m_src_r), to_monomial(nullptr, m_dst_r), j));
init_equation(eq(to_monomial(m_src_r), to_monomial(m_dst_r), j));
m_src_r.shrink(src_r_size);
m_src_r.reset();
m_src_r.append(monomial(src.r).m_nodes);
}
bool ac_plugin::are_equal(monomial_t& a, monomial_t& b) {
@ -804,52 +994,41 @@ namespace euf {
m_monomial_table.reset();
for (auto const& s1 : m_shared) {
shared s2;
if (m_monomial_table.find(s1.m, s2)) {
if (s2.n->get_root() != s1.n->get_root())
push_merge(s1.n, s2.n, justification::dependent(m_dep_manager.mk_join(m_dep_manager.mk_leaf(s1.j), m_dep_manager.mk_leaf(s2.j))));
}
else
TRACE("plugin", tout << "shared " << m_pp(*this, monomial(s1.m)) << "\n");
if (!m_monomial_table.find(s1.m, s2))
m_monomial_table.insert(s1.m, s1);
else if (s2.n->get_root() != s1.n->get_root()) {
TRACE("plugin", tout << m_pp(*this, monomial(s1.m)) << " == " << m_pp(*this, monomial(s2.m)) << "\n");
push_merge(s1.n, s2.n, justification::dependent(m_dep_manager.mk_join(m_dep_manager.mk_leaf(s1.j), m_dep_manager.mk_leaf(s2.j))));
}
}
}
void ac_plugin::simplify_shared(unsigned idx, shared s) {
bool change = true;
while (change) {
change = false;
auto & m = monomial(s.m);
init_ref_counts(m, m_dst_l_counts);
init_subset_iterator(UINT_MAX, m);
for (auto eq : m_eq_occurs) {
auto& src = m_eqs[eq];
if (!can_be_subset(monomial(src.l), m))
continue;
if (!is_subset(m_dst_l_counts, m_src_l_counts, monomial(src.l)))
continue;
m_update_shared_trail.push_back({ idx, s });
push_undo(is_update_shared);
unsigned new_m = rewrite(monomial(src.r), m);
m_shared[idx].m = new_m;
m_shared[idx].j = justification::dependent(m_dep_manager.mk_join(m_dep_manager.mk_leaf(s.j), justify_equation(eq)));
// update shared occurrences for members of the new monomial that are not already in the old monomial.
for (auto n : monomial(s.m))
n->root->n->mark1();
for (auto n : monomial(new_m))
if (!n->root->n->is_marked1()) {
n->root->shared.push_back(s.m);
m_shared_todo.insert(s.m);
m_node_trail.push_back(n->root);
push_undo(is_add_shared);
}
for (auto n : monomial(s.m))
n->root->n->unmark1();
auto j = s.j;
auto old_m = s.m;
ptr_vector<node> m1(monomial(old_m).m_nodes);
TRACE("plugin", tout << "simplify " << m_pp(*this, monomial(old_m)) << "\n");
if (!reduce(m1, j))
return;
s = m_shared[idx];
change = true;
break;
auto new_m = to_monomial(m1);
// update shared occurrences for members of the new monomial that are not already in the old monomial.
for (auto n : monomial(old_m))
n->root->n->mark1();
for (auto n : m1)
if (!n->root->n->is_marked1()) {
n->root->shared.push_back(idx);
m_shared_todo.insert(idx);
m_node_trail.push_back(n->root);
push_undo(is_add_shared_index);
}
}
for (auto n : monomial(old_m))
n->root->n->unmark1();
m_update_shared_trail.push_back({ idx, s });
push_undo(is_update_shared);
m_shared[idx].m = new_m;
m_shared[idx].j = j;
}
justification ac_plugin::justify_rewrite(unsigned eq1, unsigned eq2) {
@ -871,4 +1050,9 @@ namespace euf {
j = m_dep_manager.mk_join(j, m_dep_manager.mk_leaf(justification::equality(n->root->n, n->n)));
return j;
}
justification ac_plugin::join(justification j, unsigned eq) {
return justification::dependent(m_dep_manager.mk_join(m_dep_manager.mk_leaf(j), justify_equation(eq)));
}
}

View file

@ -101,6 +101,7 @@ namespace euf {
bloom m_bloom;
node* operator[](unsigned i) const { return m_nodes[i]; }
unsigned size() const { return m_nodes.size(); }
void set(ptr_vector<node> const& ns) { m_nodes.reset(); m_nodes.append(ns); m_bloom.m_tick = 0; }
node* const* begin() const { return m_nodes.begin(); }
node* const* end() const { return m_nodes.end(); }
node* * begin() { return m_nodes.begin(); }
@ -136,10 +137,12 @@ namespace euf {
}
};
unsigned m_fid;
unsigned m_op;
unsigned m_fid = 0;
unsigned m_op = null_decl_kind;
func_decl* m_decl = nullptr;
vector<eq> m_eqs;
ptr_vector<node> m_nodes;
bool_vector m_shared_nodes;
vector<monomial_t> m_monomials;
svector<shared> m_shared;
justification::dependency_manager m_dep_manager;
@ -161,7 +164,8 @@ namespace euf {
is_add_node,
is_merge_node,
is_update_eq,
is_add_shared,
is_add_shared_index,
is_add_eq_index,
is_register_shared,
is_update_shared
};
@ -177,19 +181,21 @@ namespace euf {
node* mk_node(enode* n);
void merge(node* r1, node* r2, justification j);
bool is_op(enode* n) const { auto d = n->get_decl(); return d && m_fid == d->get_family_id() && m_op == d->get_decl_kind(); }
bool is_op(enode* n) const { auto d = n->get_decl(); return d && (d == m_decl || (m_fid == d->get_family_id() && m_op == d->get_decl_kind())); }
std::function<void(void)> m_undo_notify;
void push_undo(undo_kind k);
enode_vector m_todo;
unsigned to_monomial(enode* n);
unsigned to_monomial(enode* n, ptr_vector<node> const& ms);
unsigned to_monomial(ptr_vector<node> const& ms) { return to_monomial(nullptr, ms); }
monomial_t const& monomial(unsigned i) const { return m_monomials[i]; }
monomial_t& monomial(unsigned i) { return m_monomials[i]; }
void sort(monomial_t& monomial);
bool is_sorted(monomial_t const& monomial) const;
uint64_t filter(monomial_t& m);
bool can_be_subset(monomial_t& subset, monomial_t& superset);
bool can_be_subset(monomial_t& subset, ptr_vector<node> const& m, bloom& b);
bool are_equal(ptr_vector<node> const& a, ptr_vector<node> const& b);
bool are_equal(monomial_t& a, monomial_t& b);
bool backward_subsumes(unsigned src_eq, unsigned dst_eq);
@ -216,14 +222,15 @@ namespace euf {
unsigned const* begin() const { return ids.begin(); }
unsigned const* end() const { return ids.end(); }
};
ref_counts m_src_l_counts, m_dst_l_counts, m_src_r_counts, m_dst_r_counts, m_eq_counts;
ref_counts m_src_l_counts, m_dst_l_counts, m_src_r_counts, m_dst_r_counts, m_eq_counts, m_m_counts;
unsigned_vector m_eq_occurs;
bool_vector m_eq_seen;
unsigned_vector const& forward_iterator(unsigned eq);
unsigned_vector const& superpose_iterator(unsigned eq);
unsigned_vector const& backward_iterator(unsigned eq);
void init_ref_counts(monomial_t const& monomial, ref_counts& counts);
void init_ref_counts(monomial_t const& monomial, ref_counts& counts) const;
void init_ref_counts(ptr_vector<node> const& monomial, ref_counts& counts) const;
void init_overlap_iterator(unsigned eq, monomial_t const& m);
void init_subset_iterator(unsigned eq, monomial_t const& m);
void compress_eq_occurs(unsigned eq_id);
@ -232,7 +239,9 @@ namespace euf {
// check that dst is a superset of dst, where src_counts are precomputed
bool is_superset(ref_counts const& src_counts, ref_counts& dst_counts, monomial_t const& dst);
unsigned rewrite(monomial_t const& src_r, monomial_t const& dst_r);
void rewrite1(ref_counts const& src_l, monomial_t const& src_r, ref_counts& dst_r_counts, ptr_vector<node>& dst_r);
bool reduce(ptr_vector<node>& m, justification& j);
void index_new_r(unsigned eq, monomial_t const& old_r, monomial_t const& new_r);
bool is_to_simplify(unsigned eq) const { return m_eqs[eq].status == eq_status::to_simplify; }
bool is_processed(unsigned eq) const { return m_eqs[eq].status == eq_status::processed; }
@ -241,11 +250,17 @@ namespace euf {
justification justify_rewrite(unsigned eq1, unsigned eq2);
justification::dependency* justify_equation(unsigned eq);
justification::dependency* justify_monomial(justification::dependency* d, monomial_t const& m);
justification join(justification j1, unsigned eq);
bool is_correct_ref_count(monomial_t const& m, ref_counts const& counts) const;
bool is_correct_ref_count(ptr_vector<node> const& m, ref_counts const& counts) const;
void register_shared(enode* n);
void propagate_shared();
void simplify_shared(unsigned idx, shared s);
std::ostream& display_monomial(std::ostream& out, monomial_t const& m) const;
std::ostream& display_monomial(std::ostream& out, monomial_t const& m) const { return display_monomial(out, m.m_nodes); }
std::ostream& display_monomial(std::ostream& out, ptr_vector<node> const& m) const;
std::ostream& display_equation(std::ostream& out, eq const& e) const;
std::ostream& display_status(std::ostream& out, eq_status s) const;
@ -254,17 +269,17 @@ namespace euf {
ac_plugin(egraph& g, unsigned fid, unsigned op);
ac_plugin(egraph& g, func_decl* f);
~ac_plugin() override {}
unsigned get_id() const override { return m_fid; }
void register_node(enode* n) override;
void register_shared(enode* n) override;
void merge_eh(enode* n1, enode* n2) override;
void merge_eh(enode* n1, enode* n2, justification j) override;
void diseq_eh(enode* n1, enode* n2) override {}
void diseq_eh(enode* eq) override;
void undo() override;
@ -282,8 +297,9 @@ namespace euf {
};
struct m_pp {
ac_plugin& p; monomial_t const& m;
m_pp(ac_plugin& p, monomial_t const& m) : p(p), m(m) {}
ac_plugin& p; ptr_vector<node> const& m;
m_pp(ac_plugin& p, monomial_t const& m) : p(p), m(m.m_nodes) {}
m_pp(ac_plugin& p, ptr_vector<node> const& m) : p(p), m(m) {}
std::ostream& display(std::ostream& out) const { return p.display_monomial(out, m); }
};
};

View file

@ -36,20 +36,9 @@ namespace euf {
// no-op
}
void arith_plugin::register_shared(enode* n) {
if (a.is_add(n->get_expr()))
m_add.register_shared(n);
if (a.is_mul(n->get_expr()))
m_mul.register_shared(n);
}
void arith_plugin::merge_eh(enode* n1, enode* n2, justification j) {
m_add.merge_eh(n1, n2, j);
m_mul.merge_eh(n1, n2, j);
}
void arith_plugin::diseq_eh(enode* n1, enode* n2) {
// no-op
void arith_plugin::merge_eh(enode* n1, enode* n2) {
m_add.merge_eh(n1, n2);
m_mul.merge_eh(n1, n2);
}
void arith_plugin::propagate() {

View file

@ -39,11 +39,9 @@ namespace euf {
void register_node(enode* n) override;
void register_shared(enode* n) override;
void merge_eh(enode* n1, enode* n2) override;
void merge_eh(enode* n1, enode* n2, justification j) override;
void diseq_eh(enode* n1, enode* n2) override;
void diseq_eh(enode* eq) override {}
void undo() override;

View file

@ -103,7 +103,7 @@ namespace euf {
return mk(e, 0, nullptr);
}
void bv_plugin::merge_eh(enode* x, enode* y, justification j) {
void bv_plugin::merge_eh(enode* x, enode* y) {
SASSERT(x == x->get_root());
SASSERT(x == y->get_root());
@ -120,7 +120,7 @@ namespace euf {
ys.reset();
xs.push_back(x);
ys.push_back(y);
merge(xs, ys, j);
merge(xs, ys, justification::equality(x, y));
}
// ensure p := concat(n1[I], n2[J]), n1 ~ n2 ~ n3 I and J are consecutive => p ~ n3[IJ]

View file

@ -86,11 +86,9 @@ namespace euf {
void register_node(enode* n) override;
void register_shared(enode* n) override {}
void merge_eh(enode* n1, enode* n2) override;
void merge_eh(enode* n1, enode* n2, justification j) override;
void diseq_eh(enode* n1, enode* n2) override {}
void diseq_eh(enode* eq) override {}
void propagate() override {}

View file

@ -20,6 +20,7 @@ Notes:
#include "ast/euf/euf_egraph.h"
#include "ast/euf/euf_bv_plugin.h"
#include "ast/euf/euf_arith_plugin.h"
#include "ast/euf/euf_specrel_plugin.h"
#include "ast/ast_pp.h"
#include "ast/ast_translation.h"
@ -115,7 +116,6 @@ namespace euf {
n->mark_interpreted();
if (m_on_make)
m_on_make(n);
register_node(n);
if (num_args == 0)
return n;
@ -134,22 +134,6 @@ namespace euf {
return n;
}
void egraph::register_node(enode* n) {
if (m_plugins.empty())
return;
auto* p = get_plugin(n);
if (p)
p->register_node(n);
if (!n->is_equality()) {
for (auto* arg : enode_args(n)) {
auto* p_arg = get_plugin(arg);
if (p != p_arg)
p_arg->register_shared(arg);
}
}
}
egraph::egraph(ast_manager& m) : m(m), m_table(m), m_tmp_app(2), m_exprs(m), m_eq_decls(m) {
m_tmp_eq = enode::mk_tmp(m_region, 2);
}
@ -162,6 +146,9 @@ namespace euf {
}
void egraph::add_plugins() {
if (!m_plugins.empty())
return;
auto insert = [&](plugin* p) {
m_plugins.reserve(p->get_id() + 1);
m_plugins.set(p->get_id(), p);
@ -169,6 +156,7 @@ namespace euf {
insert(alloc(bv_plugin, *this));
insert(alloc(arith_plugin, *this));
insert(alloc(specrel_plugin, *this));
}
void egraph::propagate_plugins() {
@ -182,14 +170,20 @@ namespace euf {
m_new_th_eqs.push_back(th_eq(id, v1, v2, c, r));
m_updates.push_back(update_record(update_record::new_th_eq()));
++m_stats.m_num_th_eqs;
auto* p = get_plugin(id);
if (p)
p->merge_eh(c, r);
}
void egraph::add_th_diseq(theory_id id, theory_var v1, theory_var v2, expr* eq) {
void egraph::add_th_diseq(theory_id id, theory_var v1, theory_var v2, enode* eq) {
if (!th_propagates_diseqs(id))
return;
TRACE("euf_verbose", tout << "eq: " << v1 << " != " << v2 << "\n";);
m_new_th_eqs.push_back(th_eq(id, v1, v2, eq));
m_new_th_eqs.push_back(th_eq(id, v1, v2, eq->get_expr()));
m_updates.push_back(update_record(update_record::new_th_eq()));
auto* p = get_plugin(id);
if (p)
p->diseq_eh(eq);
++m_stats.m_num_th_diseqs;
}
@ -238,7 +232,7 @@ namespace euf {
return;
theory_var v1 = arg1->get_closest_th_var(id);
theory_var v2 = arg2->get_closest_th_var(id);
add_th_diseq(id, v1, v2, n->get_expr());
add_th_diseq(id, v1, v2, n);
return;
}
for (auto const& p : euf::enode_th_vars(r1)) {
@ -246,7 +240,7 @@ namespace euf {
continue;
for (auto const& q : euf::enode_th_vars(r2))
if (p.get_id() == q.get_id())
add_th_diseq(p.get_id(), p.get_var(), q.get_var(), n->get_expr());
add_th_diseq(p.get_id(), p.get_var(), q.get_var(), n);
}
}
@ -266,7 +260,7 @@ namespace euf {
n = n->get_root();
theory_var v2 = n->get_closest_th_var(id);
if (v2 != null_theory_var)
add_th_diseq(id, v1, v2, p->get_expr());
add_th_diseq(id, v1, v2, p);
}
}
}
@ -285,6 +279,10 @@ namespace euf {
theory_var w = n->get_th_var(id);
enode* r = n->get_root();
auto* p = get_plugin(id);
if (p)
p->register_node(n);
if (w == null_theory_var) {
n->add_th_var(v, id, m_region);
m_updates.push_back(update_record(n, id, update_record::add_th_var()));
@ -529,10 +527,7 @@ namespace euf {
for (auto& cb : m_on_merge)
cb(r2, r1);
auto* p = get_plugin(r1);
if (p)
p->merge_eh(r2, r1, j);
}
void egraph::remove_parents(enode* r) {

View file

@ -215,8 +215,6 @@ namespace euf {
// plugin related methods
void push_plugin_undo(unsigned th_id) { m_updates.push_back(update_record(th_id, update_record::plugin_undo())); }
void push_merge(enode* a, enode* b, justification j) { m_to_merge.push_back({ a, b, j }); }
plugin* get_plugin(enode* n) { return m_plugins.get(n->get_sort()->get_family_id(), nullptr); }
void register_node(enode* n);
void propagate_plugins();
void add_th_eq(theory_id id, theory_var v1, theory_var v2, enode* c, enode* r);
@ -259,6 +257,7 @@ namespace euf {
egraph(ast_manager& m);
~egraph();
void add_plugins();
plugin* get_plugin(family_id fid) const { return m_plugins.get(fid, nullptr); }
enode* find(expr* f) const { return m_expr2enode.get(f->get_id(), nullptr); }
enode* find(expr* f, unsigned n, enode* const* args);
enode* mk(expr* f, unsigned generation, unsigned n, enode *const* args);
@ -302,7 +301,7 @@ namespace euf {
where \c n is an enode and \c is_eq indicates whether the enode
is an equality consequence.
*/
void add_th_diseq(theory_id id, theory_var v1, theory_var v2, expr* eq);
void add_th_diseq(theory_id id, theory_var v1, theory_var v2, enode* eq);
bool has_th_eq() const { return m_new_th_eqs_qhead < m_new_th_eqs.size(); }
th_eq get_th_eq() const { return m_new_th_eqs[m_new_th_eqs_qhead]; }
void next_th_eq() { force_push(); SASSERT(m_new_th_eqs_qhead < m_new_th_eqs.size()); m_new_th_eqs_qhead++; }

View file

@ -93,6 +93,17 @@ namespace euf {
return null_theory_var;
}
enode* enode::get_closest_th_node(theory_id id) {
enode* n = this;
while (n) {
theory_var v = n->get_th_var(id);
if (v != null_theory_var)
return n;
n = n->m_target;
}
return nullptr;
}
bool enode::acyclic() const {
enode const* n = this;
enode const* p = this;

View file

@ -229,6 +229,7 @@ namespace euf {
theory_var get_th_var(theory_id id) const { return m_th_vars.find(id); }
theory_var get_closest_th_var(theory_id id) const;
enode* get_closest_th_node(theory_id id);
bool is_attached_to(theory_id id) const { return get_th_var(id) != null_theory_var; }
bool has_th_vars() const { return !m_th_vars.empty(); }
bool has_one_th_var() const { return !m_th_vars.empty() && !m_th_vars.get_next();}

View file

@ -0,0 +1,54 @@
/*++
Copyright (c) 2020 Microsoft Corporation
Module Name:
euf_justification.cpp
Abstract:
justification structure for euf
Author:
Nikolaj Bjorner (nbjorner) 2020-08-23
--*/
#include "ast/euf/euf_justification.h"
#include "ast/euf/euf_enode.h"
namespace euf {
std::ostream& justification::display(std::ostream& out, std::function<void (std::ostream&, void*)> const& ext) const {
switch (m_kind) {
case kind_t::external_t:
if (ext)
ext(out, m_external);
else
out << "external";
return out;
case kind_t::axiom_t:
return out << "axiom";
case kind_t::congruence_t:
return out << "congruence";
case kind_t::dependent_t: {
vector<justification, false> js;
out << "dependent";
for (auto const& j : dependency_manager::s_linearize(m_dependency, js))
j.display(out << " ", ext);
return out;
}
case kind_t::equality_t:
return out << "equality #" << m_n1->get_id() << " == #" << m_n2->get_id();
default:
UNREACHABLE();
return out;
}
return out;
}
}

View file

@ -118,31 +118,8 @@ namespace euf {
}
}
std::ostream& display(std::ostream& out, std::function<void (std::ostream&, void*)> const& ext) const {
switch (m_kind) {
case kind_t::external_t:
if (ext)
ext(out, m_external);
else
out << "external";
return out;
case kind_t::axiom_t:
return out << "axiom";
case kind_t::congruence_t:
return out << "congruence";
case kind_t::dependent_t: {
vector<justification, false> js;
out << "dependent";
for (auto const& j : dependency_manager::s_linearize(m_dependency, js))
j.display(out << " ", ext);
return out;
}
default:
UNREACHABLE();
return out;
}
return out;
}
std::ostream& display(std::ostream& out, std::function<void(std::ostream&, void*)> const& ext) const;
};
inline std::ostream& operator<<(std::ostream& out, justification const& j) {

View file

@ -43,12 +43,10 @@ namespace euf {
virtual unsigned get_id() const = 0;
virtual void register_node(enode* n) = 0;
virtual void register_shared(enode* n) = 0;
virtual void merge_eh(enode* n1, enode* n2, justification j) = 0;
virtual void merge_eh(enode* n1, enode* n2) = 0;
virtual void diseq_eh(enode* n1, enode* n2) = 0;
virtual void diseq_eh(enode* eq) {};
virtual void propagate() = 0;

View file

@ -0,0 +1,71 @@
/*++
Copyright (c) 2023 Microsoft Corporation
Module Name:
euf_specrel_plugin.cpp
Abstract:
plugin structure for specrel
Author:
Nikolaj Bjorner (nbjorner) 2023-11-11
--*/
#include "ast/euf/euf_specrel_plugin.h"
#include "ast/euf/euf_egraph.h"
#include <algorithm>
namespace euf {
specrel_plugin::specrel_plugin(egraph& g) :
plugin(g),
sp(g.get_manager()) {
}
void specrel_plugin::register_node(enode* n) {
func_decl* f = n->get_decl();
if (!f)
return;
if (!sp.is_ac(f))
return;
ac_plugin* p = nullptr;
if (!m_decl2plugin.find(f, p)) {
p = alloc(ac_plugin, g, f);
m_decl2plugin.insert(f, p);
m_plugins.push_back(p);
std::function<void(void)> undo_op = [&]() { m_undo.push_back(p); };
p->set_undo(undo_op);
}
}
void specrel_plugin::merge_eh(enode* n1, enode* n2) {
for (auto * p : m_plugins)
p->merge_eh(n1, n2);
}
void specrel_plugin::diseq_eh(enode* eq) {
for (auto* p : m_plugins)
p->diseq_eh(eq);
}
void specrel_plugin::propagate() {
for (auto * p : m_plugins)
p->propagate();
}
void specrel_plugin::undo() {
auto p = m_undo.back();
m_undo.pop_back();
p->undo();
}
std::ostream& specrel_plugin::display(std::ostream& out) const {
for (auto * p : m_plugins)
p->display(out);
return out;
}
}

View file

@ -0,0 +1,56 @@
/*++
Copyright (c) 2023 Microsoft Corporation
Module Name:
euf_specrel_plugin.h
Abstract:
plugin structure for specrel functions
Author:
Nikolaj Bjorner (nbjorner) 2023-11-11
--*/
#pragma once
#include <iostream>
#include "util/scoped_ptr_vector.h"
#include "ast/special_relations_decl_plugin.h"
#include "ast/euf/euf_plugin.h"
#include "ast/euf/euf_ac_plugin.h"
namespace euf {
class specrel_plugin : public plugin {
scoped_ptr_vector<ac_plugin> m_plugins;
ptr_vector<ac_plugin> m_undo;
obj_map<func_decl, ac_plugin*> m_decl2plugin;
special_relations_util sp;
public:
specrel_plugin(egraph& g);
~specrel_plugin() override {}
unsigned get_id() const override { return sp.get_family_id(); }
void register_node(enode* n) override;
void merge_eh(enode* n1, enode* n2) override;
void diseq_eh(enode* eq) override;
void undo() override;
void propagate() override;
std::ostream& display(std::ostream& out) const override;
};
}

View file

@ -26,7 +26,8 @@ special_relations_decl_plugin::special_relations_decl_plugin():
m_po("partial-order"),
m_plo("piecewise-linear-order"),
m_to("tree-order"),
m_tc("transitive-closure")
m_tc("transitive-closure"),
m_ac("ac-op")
{}
func_decl * special_relations_decl_plugin::mk_func_decl(
@ -41,24 +42,53 @@ func_decl * special_relations_decl_plugin::mk_func_decl(
m_manager->raise_exception("argument sort missmatch. The two arguments should have the same sort");
return nullptr;
}
if (!range && k == OP_SPECIAL_RELATION_AC)
range = domain[0];
if (!range) {
range = m_manager->mk_bool_sort();
}
if (!m_manager->is_bool(range)) {
m_manager->raise_exception("range type is expected to be Boolean for special relations");
}
auto check_bool_range = [&]() {
if (!m_manager->is_bool(range))
m_manager->raise_exception("range type is expected to be Boolean for special relations");
};
m_has_special_relation = true;
func_decl_info info(m_family_id, k, num_parameters, parameters);
symbol name;
switch(k) {
case OP_SPECIAL_RELATION_PO: name = m_po; break;
case OP_SPECIAL_RELATION_LO: name = m_lo; break;
case OP_SPECIAL_RELATION_PLO: name = m_plo; break;
case OP_SPECIAL_RELATION_TO: name = m_to; break;
case OP_SPECIAL_RELATION_PO: check_bool_range(); name = m_po; break;
case OP_SPECIAL_RELATION_LO: check_bool_range(); name = m_lo; break;
case OP_SPECIAL_RELATION_PLO: check_bool_range(); name = m_plo; break;
case OP_SPECIAL_RELATION_TO: check_bool_range(); name = m_to; break;
case OP_SPECIAL_RELATION_AC: {
if (range != domain[0])
m_manager->raise_exception("AC operation should have the same range as domain type");
name = m_ac;
if (num_parameters != 1 || !parameters[0].is_ast() || !is_func_decl(parameters[0].get_ast()))
m_manager->raise_exception("parameter to transitive closure should be a function declaration");
func_decl* f = to_func_decl(parameters[0].get_ast());
if (f->get_arity() != 2)
m_manager->raise_exception("ac function should be binary");
if (f->get_domain(0) != f->get_domain(1))
m_manager->raise_exception("ac function should have same domain");
if (f->get_domain(0) != f->get_range())
m_manager->raise_exception("ac function should have same domain and range");
break;
}
case OP_SPECIAL_RELATION_TC:
check_bool_range();
name = m_tc;
if (num_parameters != 1 || !parameters[0].is_ast() || !is_func_decl(parameters[0].get_ast()))
m_manager->raise_exception("parameter to transitive closure should be a function declaration");
func_decl* f = to_func_decl(parameters[0].get_ast());
if (f->get_arity() != 2)
m_manager->raise_exception("tc relation should be binary");
if (f->get_domain(0) != f->get_domain(1))
m_manager->raise_exception("tc relation should have same domain");
if (!m_manager->is_bool(f->get_range()))
m_manager->raise_exception("tc relation should be Boolean");
break;
}
return m_manager->mk_func_decl(name, arity, domain, range, info);
@ -71,6 +101,7 @@ void special_relations_decl_plugin::get_op_names(svector<builtin_name> & op_name
op_names.push_back(builtin_name(m_plo.str(), OP_SPECIAL_RELATION_PLO));
op_names.push_back(builtin_name(m_to.str(), OP_SPECIAL_RELATION_TO));
op_names.push_back(builtin_name(m_tc.str(), OP_SPECIAL_RELATION_TC));
op_names.push_back(builtin_name(m_ac.str(), OP_SPECIAL_RELATION_AC));
}
}
@ -81,6 +112,7 @@ sr_property special_relations_util::get_property(func_decl* f) const {
case OP_SPECIAL_RELATION_PLO: return sr_plo;
case OP_SPECIAL_RELATION_TO: return sr_to;
case OP_SPECIAL_RELATION_TC: return sr_tc;
case OP_SPECIAL_RELATION_AC: return sr_none;
default:
UNREACHABLE();
return sr_po;

View file

@ -16,6 +16,8 @@ Author:
Revision History:
2023-11-27: Added ac-op for E-graph plugin
--*/
#pragma once
@ -28,6 +30,7 @@ enum special_relations_op_kind {
OP_SPECIAL_RELATION_PLO,
OP_SPECIAL_RELATION_TO,
OP_SPECIAL_RELATION_TC,
OP_SPECIAL_RELATION_AC,
LAST_SPECIAL_RELATIONS_OP
};
@ -37,6 +40,7 @@ class special_relations_decl_plugin : public decl_plugin {
symbol m_plo;
symbol m_to;
symbol m_tc;
symbol m_ac;
bool m_has_special_relation = false;
public:
special_relations_decl_plugin();
@ -86,13 +90,16 @@ class special_relations_util {
public:
special_relations_util(ast_manager& m) : m(m), m_fid(null_family_id) { }
family_id get_family_id() const { return fid(); }
bool has_special_relation() const { return static_cast<special_relations_decl_plugin*>(m.get_plugin(m.mk_family_id("specrels")))->has_special_relation(); }
bool is_special_relation(func_decl* f) const { return f->get_family_id() == fid(); }
bool is_special_relation(app* e) const { return is_special_relation(e->get_decl()); }
bool is_special_relation(expr* e) const { return is_app(e) && is_special_relation(to_app(e)->get_decl()); }
sr_property get_property(func_decl* f) const;
sr_property get_property(app* e) const { return get_property(e->get_decl()); }
func_decl* get_relation(func_decl* f) const { SASSERT(is_special_relation(f)); return to_func_decl(f->get_parameter(0).get_ast()); }
func_decl* get_relation(expr* e) const { SASSERT(is_special_relation(e)); return to_func_decl(to_app(e)->get_parameter(0).get_ast()); }
func_decl* mk_to_decl(func_decl* f) { return mk_rel_decl(f, OP_SPECIAL_RELATION_TO); }
func_decl* mk_po_decl(func_decl* f) { return mk_rel_decl(f, OP_SPECIAL_RELATION_PO); }
@ -105,12 +112,14 @@ public:
bool is_plo(expr const * e) const { return is_app_of(e, fid(), OP_SPECIAL_RELATION_PLO); }
bool is_to(expr const * e) const { return is_app_of(e, fid(), OP_SPECIAL_RELATION_TO); }
bool is_tc(expr const * e) const { return is_app_of(e, fid(), OP_SPECIAL_RELATION_TC); }
bool is_ac(expr const* e) const { return is_app_of(e, fid(), OP_SPECIAL_RELATION_AC); }
bool is_lo(func_decl const * e) const { return is_decl_of(e, fid(), OP_SPECIAL_RELATION_LO); }
bool is_po(func_decl const * e) const { return is_decl_of(e, fid(), OP_SPECIAL_RELATION_PO); }
bool is_plo(func_decl const * e) const { return is_decl_of(e, fid(), OP_SPECIAL_RELATION_PLO); }
bool is_to(func_decl const * e) const { return is_decl_of(e, fid(), OP_SPECIAL_RELATION_TO); }
bool is_tc(func_decl const * e) const { return is_decl_of(e, fid(), OP_SPECIAL_RELATION_TC); }
bool is_ac(func_decl const* e) const { return is_decl_of(e, fid(), OP_SPECIAL_RELATION_AC); }
app * mk_lo (expr * arg1, expr * arg2) { return m.mk_app( fid(), OP_SPECIAL_RELATION_LO, arg1, arg2); }
app * mk_po (expr * arg1, expr * arg2) { return m.mk_app( fid(), OP_SPECIAL_RELATION_PO, arg1, arg2); }

View file

@ -45,6 +45,7 @@ z3_add_component(sat_smt
q_solver.cpp
recfun_solver.cpp
sat_th.cpp
specrel_solver.cpp
tseitin_theory_checker.cpp
user_solver.cpp
COMPONENT_DEPENDENCIES

View file

@ -7,7 +7,7 @@ Module Name:
Abstract:
Theory plugin for altegraic datatypes
Theory plugin for algebraic datatypes
Author:

View file

@ -7,7 +7,7 @@ Module Name:
Abstract:
Theory plugin for altegraic datatypes
Theory plugin for algebraic datatypes
Author:

View file

@ -28,6 +28,7 @@ Author:
#include "sat/smt/fpa_solver.h"
#include "sat/smt/dt_solver.h"
#include "sat/smt/recfun_solver.h"
#include "sat/smt/specrel_solver.h"
namespace euf {
@ -130,6 +131,7 @@ namespace euf {
arith_util arith(m);
datatype_util dt(m);
recfun::util rf(m);
special_relations_util sp(m);
if (pb.get_family_id() == fid)
ext = alloc(pb::solver, *this, fid);
else if (bvu.get_family_id() == fid)
@ -144,7 +146,8 @@ namespace euf {
ext = alloc(dt::solver, *this, fid);
else if (rf.get_family_id() == fid)
ext = alloc(recfun::solver, *this);
else if (sp.get_family_id() == fid)
ext = alloc(specrel::solver, *this, fid);
if (ext)
add_solver(ext);
else if (f)

View file

@ -0,0 +1,119 @@
/*++
Copyright (c) 2020 Microsoft Corporation
Module Name:
specrel_solver.h
Abstract:
Theory plugin for special relations
Author:
Nikolaj Bjorner (nbjorner) 2020-09-08
--*/
#include "sat/smt/specrel_solver.h"
#include "sat/smt/euf_solver.h"
namespace euf {
class solver;
}
namespace specrel {
solver::solver(euf::solver& ctx, theory_id id) :
th_euf_solver(ctx, ctx.get_manager().get_family_name(id), id),
sp(m)
{
ctx.get_egraph().add_plugins();
}
solver::~solver() {
}
void solver::asserted(sat::literal l) {
}
sat::check_result solver::check() {
return sat::check_result::CR_DONE;
}
std::ostream& solver::display(std::ostream& out) const {
return out;
}
void solver::collect_statistics(statistics& st) const {
}
euf::th_solver* solver::clone(euf::solver& ctx) {
return alloc(solver, ctx, get_id());
}
void solver::new_eq_eh(euf::th_eq const& eq) {
TRACE("specrel", tout << "new-eq\n");
if (eq.is_eq()) {
auto* p = ctx.get_egraph().get_plugin(sp.get_family_id());
p->merge_eh(var2enode(eq.v1()), var2enode(eq.v2()));
TRACE("specrel", tout << eq.v1() << " " << eq.v2() << "\n");
}
}
void solver::add_value(euf::enode* n, model& mdl, expr_ref_vector& values) {
}
bool solver::add_dep(euf::enode* n, top_sort<euf::enode>& dep) {
return false;
}
bool solver::include_func_interp(func_decl* f) const {
return false;
}
sat::literal solver::internalize(expr* e, bool sign, bool root) {
if (!visit_rec(m, e, sign, root))
return sat::null_literal;
auto lit = ctx.expr2literal(e);
if (sign)
lit.neg();
return lit;
}
void solver::internalize(expr* e) {
visit_rec(m, e, false, false);
}
bool solver::visit(expr* e) {
if (visited(e))
return true;
m_stack.push_back(sat::eframe(e));
return false;
}
bool solver::visited(expr* e) {
euf::enode* n = expr2enode(e);
return n && n->is_attached_to(get_id());
}
bool solver::post_visit(expr* term, bool sign, bool root) {
euf::enode* n = expr2enode(term);
SASSERT(!n || !n->is_attached_to(get_id()));
if (!n)
n = mk_enode(term);
SASSERT(!n->is_attached_to(get_id()));
mk_var(n);
TRACE("specrel", tout << ctx.bpp(n) << "\n");
return true;
}
euf::theory_var solver::mk_var(euf::enode* n) {
if (is_attached_to_var(n))
return n->get_th_var(get_id());
euf::theory_var r = th_euf_solver::mk_var(n);
ctx.attach_th_var(n, this, r);
return r;
}
}

View file

@ -0,0 +1,75 @@
/*++
Copyright (c) 2020 Microsoft Corporation
Module Name:
specrel_solver.h
Abstract:
Theory plugin for special relations
Author:
Nikolaj Bjorner (nbjorner) 2020-09-08
--*/
#pragma once
#include "sat/smt/sat_th.h"
#include "ast/special_relations_decl_plugin.h"
namespace euf {
class solver;
}
namespace specrel {
class solver : public euf::th_euf_solver {
typedef euf::theory_var theory_var;
typedef euf::theory_id theory_id;
typedef euf::enode enode;
typedef euf::enode_pair enode_pair;
typedef euf::enode_pair_vector enode_pair_vector;
typedef sat::bool_var bool_var;
typedef sat::literal literal;
typedef sat::literal_vector literal_vector;
special_relations_util sp;
public:
solver(euf::solver& ctx, theory_id id);
~solver() override;
bool is_external(bool_var v) override { return false; }
void get_antecedents(literal l, sat::ext_justification_idx idx, literal_vector& r, bool probing) override {}
void asserted(literal l) override;
sat::check_result check() override;
std::ostream& display(std::ostream& out) const override;
std::ostream& display_justification(std::ostream& out, sat::ext_justification_idx idx) const override { return euf::th_explain::from_index(idx).display(out); }
std::ostream& display_constraint(std::ostream& out, sat::ext_constraint_idx idx) const override { return display_justification(out, idx); }
void collect_statistics(statistics& st) const override;
euf::th_solver* clone(euf::solver& ctx) override;
void new_eq_eh(euf::th_eq const& eq) override;
bool unit_propagate() override { return false; }
void add_value(euf::enode* n, model& mdl, expr_ref_vector& values) override;
bool add_dep(euf::enode* n, top_sort<euf::enode>& dep) override;
bool include_func_interp(func_decl* f) const override;
sat::literal internalize(expr* e, bool sign, bool root) override;
void internalize(expr* e) override;
bool visit(expr* e) override;
bool visited(expr* e) override;
bool post_visit(expr* e, bool sign, bool root) override;
euf::theory_var mk_var(euf::enode* n) override;
void apply_sort_cnstr(euf::enode* n, sort* s) override {}
bool is_shared(theory_var v) const override { return false; }
lbool get_phase(bool_var v) override { return l_true; }
bool enable_self_propagate() const override { return true; }
void merge_eh(theory_var, theory_var, theory_var v1, theory_var v2);
void after_merge_eh(theory_var r1, theory_var r2, theory_var v1, theory_var v2) {}
void unmerge_eh(theory_var v1, theory_var v2) {}
};
}

View file

@ -11,14 +11,18 @@ Copyright (c) 2023 Microsoft Corporation
#include "ast/ast_pp.h"
#include <iostream>
static euf::enode* get_node(euf::egraph& g, expr* e) {
unsigned s_var = 0;
static euf::enode* get_node(euf::egraph& g, arith_util& a, expr* e) {
auto* n = g.find(e);
if (n)
return n;
euf::enode_vector args;
for (expr* arg : *to_app(e))
args.push_back(get_node(g, arg));
return g.mk(e, 0, args.size(), args.data());
args.push_back(get_node(g, a, arg));
n = g.mk(e, 0, args.size(), args.data());
g.add_th_var(n, s_var++, a.get_family_id());
return n;
}
//
@ -32,15 +36,15 @@ static void test1() {
expr_ref x(m.mk_const("x", I), m);
expr_ref y(m.mk_const("y", I), m);
auto* nx = get_node(g, a.mk_add(a.mk_add(y, y), a.mk_add(x, x)));
auto* ny = get_node(g, a.mk_add(a.mk_add(y, x), x));
auto* nx = get_node(g, a, a.mk_add(a.mk_add(y, y), a.mk_add(x, x)));
auto* ny = get_node(g, a, a.mk_add(a.mk_add(y, x), x));
TRACE("plugin", tout << "before merge\n" << g << "\n");
g.merge(nx, ny, nullptr);
TRACE("plugin", tout << "before propagate\n" << g << "\n");
g.propagate();
TRACE("plugin", tout << "after propagate\n" << g << "\n");
g.merge(get_node(g, a.mk_add(x, a.mk_add(y, y))), get_node(g, a.mk_add(y, x)), nullptr);
g.merge(get_node(g, a, a.mk_add(x, a.mk_add(y, y))), get_node(g, a, a.mk_add(y, x)), nullptr);
g.propagate();
std::cout << g << "\n";
}
@ -55,10 +59,10 @@ static void test2() {
expr_ref x(m.mk_const("x", I), m);
expr_ref y(m.mk_const("y", I), m);
auto* nxy = get_node(g, a.mk_add(x, y));
auto* nyx = get_node(g, a.mk_add(y, x));
auto* nx = get_node(g, x);
auto* ny = get_node(g, y);
auto* nxy = get_node(g, a, a.mk_add(x, y));
auto* nyx = get_node(g, a, a.mk_add(y, x));
auto* nx = get_node(g, a, x);
auto* ny = get_node(g, a, y);
TRACE("plugin", tout << "before merge\n" << g << "\n");
g.merge(nxy, nx, nullptr);
@ -67,7 +71,7 @@ static void test2() {
g.propagate();
TRACE("plugin", tout << "after propagate\n" << g << "\n");
SASSERT(nx->get_root() == ny->get_root());
g.merge(get_node(g, a.mk_add(x, a.mk_add(y, y))), get_node(g, a.mk_add(y, x)), nullptr);
g.merge(get_node(g, a, a.mk_add(x, a.mk_add(y, y))), get_node(g, a, a.mk_add(y, x)), nullptr);
g.propagate();
std::cout << g << "\n";
}
@ -82,22 +86,21 @@ static void test3() {
expr_ref x(m.mk_const("x", I), m);
expr_ref y(m.mk_const("y", I), m);
auto* nxyy = get_node(g, a.mk_add(a.mk_add(x, y), y));
auto* nyxx = get_node(g, a.mk_add(a.mk_add(y, x), x));
auto* nx = get_node(g, x);
auto* ny = get_node(g, y);
auto* nxyy = get_node(g, a, a.mk_add(a.mk_add(x, y), y));
auto* nyxx = get_node(g, a, a.mk_add(a.mk_add(y, x), x));
auto* nx = get_node(g, a, x);
auto* ny = get_node(g, a, y);
g.merge(nxyy, nx, nullptr);
g.merge(nyxx, ny, nullptr);
TRACE("plugin", tout << "before propagate\n" << g << "\n");
g.propagate();
TRACE("plugin", tout << "after propagate\n" << g << "\n");
SASSERT(nx->get_root() == ny->get_root());
std::cout << g << "\n";
}
void tst_euf_arith_plugin() {
enable_trace("plugin");
test3();
test1();
test2();
test3();
}

View file

@ -11,14 +11,17 @@ Copyright (c) 2023 Microsoft Corporation
#include "ast/ast_pp.h"
#include <iostream>
static euf::enode* get_node(euf::egraph& g, expr* e) {
static unsigned s_var = 0;
static euf::enode* get_node(euf::egraph& g, bv_util& b, expr* e) {
auto* n = g.find(e);
if (n)
return n;
euf::enode_vector args;
for (expr* arg : *to_app(e))
args.push_back(get_node(g, arg));
return g.mk(e, 0, args.size(), args.data());
args.push_back(get_node(g, b, arg));
n = g.mk(e, 0, args.size(), args.data());
g.add_th_var(n, s_var++, b.get_family_id());
return n;
}
// align slices, and propagate extensionality
@ -40,8 +43,8 @@ static void test1() {
expr_ref y1(bv.mk_extract(7, 0, y), m);
expr_ref xx(bv.mk_concat(x1, bv.mk_concat(x2, x3)), m);
expr_ref yy(bv.mk_concat(y1, bv.mk_concat(y2, y3)), m);
auto* nx = get_node(g, xx);
auto* ny = get_node(g, yy);
auto* nx = get_node(g, bv, xx);
auto* ny = get_node(g, bv, yy);
TRACE("bv", tout << "before merge\n" << g << "\n");
g.merge(nx, ny, nullptr);
TRACE("bv", tout << "before propagate\n" << g << "\n");
@ -65,12 +68,12 @@ static void test2() {
expr_ref x2(bv.mk_extract(15, 8, x), m);
expr_ref x1(bv.mk_extract(7, 0, x), m);
expr_ref xx(bv.mk_concat(x1, bv.mk_concat(x2, x3)), m);
g.merge(get_node(g, xx), get_node(g, bv.mk_numeral((1 << 27) + (1 << 17) + (1 << 3), 32)), nullptr);
g.merge(get_node(g, bv, xx), get_node(g, bv, bv.mk_numeral((1 << 27) + (1 << 17) + (1 << 3), 32)), nullptr);
g.propagate();
SASSERT(get_node(g, x1)->get_root()->interpreted());
SASSERT(get_node(g, x2)->get_root()->interpreted());
SASSERT(get_node(g, x3)->get_root()->interpreted());
SASSERT(get_node(g, x)->get_root()->interpreted());
SASSERT(get_node(g, bv, x1)->get_root()->interpreted());
SASSERT(get_node(g, bv, x2)->get_root()->interpreted());
SASSERT(get_node(g, bv, x3)->get_root()->interpreted());
SASSERT(get_node(g, bv, x)->get_root()->interpreted());
}
@ -89,13 +92,13 @@ static void test3() {
expr_ref x1(bv.mk_extract(7, 0, x), m);
expr_ref xx(bv.mk_concat(bv.mk_concat(x1, x2), x3), m);
expr_ref y(m.mk_const("y", u32), m);
g.merge(get_node(g, xx), get_node(g, y), nullptr);
g.merge(get_node(g, x1), get_node(g, bv.mk_numeral(2, 8)), nullptr);
g.merge(get_node(g, x2), get_node(g, bv.mk_numeral(8, 8)), nullptr);
g.merge(get_node(g, bv, xx), get_node(g, bv, y), nullptr);
g.merge(get_node(g, bv, x1), get_node(g, bv, bv.mk_numeral(2, 8)), nullptr);
g.merge(get_node(g, bv, x2), get_node(g, bv, bv.mk_numeral(8, 8)), nullptr);
g.propagate();
SASSERT(get_node(g, bv.mk_concat(x1, x2))->get_root()->interpreted());
SASSERT(get_node(g, x1)->get_root()->interpreted());
SASSERT(get_node(g, x2)->get_root()->interpreted());
SASSERT(get_node(g, bv, bv.mk_concat(x1, x2))->get_root()->interpreted());
SASSERT(get_node(g, bv, x1)->get_root()->interpreted());
SASSERT(get_node(g, bv, x2)->get_root()->interpreted());
}
// propagate extract up
@ -114,11 +117,11 @@ static void test4() {
expr_ref y(m.mk_const("y", u16), m);
expr_ref x1(bv.mk_extract(15, 8, x), m);
expr_ref x2(bv.mk_extract(23, 16, x), m);
g.merge(get_node(g, bv.mk_concat(a, x2)), get_node(g, y), nullptr);
g.merge(get_node(g, x1), get_node(g, a), nullptr);
g.merge(get_node(g, bv, bv.mk_concat(a, x2)), get_node(g, bv, y), nullptr);
g.merge(get_node(g, bv, x1), get_node(g, bv, a), nullptr);
g.propagate();
TRACE("bv", tout << g << "\n");
SASSERT(get_node(g, bv.mk_extract(23, 8, x))->get_root() == get_node(g, y)->get_root());
SASSERT(get_node(g, bv, bv.mk_extract(23, 8, x))->get_root() == get_node(g, bv, y)->get_root());
}
// iterative slicing
@ -133,8 +136,8 @@ static void test5() {
expr_ref x(m.mk_const("x", u32), m);
expr_ref x1(bv.mk_extract(31, 4, x), m);
expr_ref x2(bv.mk_extract(27, 0, x), m);
auto* nx = get_node(g, x1);
auto* ny = get_node(g, x2);
auto* nx = get_node(g, bv, x1);
auto* ny = get_node(g, bv, x2);
TRACE("bv", tout << "before merge\n" << g << "\n");
g.merge(nx, ny, nullptr);
TRACE("bv", tout << "before propagate\n" << g << "\n");
@ -155,8 +158,8 @@ static void test6() {
expr_ref x(m.mk_const("x", u32), m);
expr_ref x1(bv.mk_extract(31, 3, x), m);
expr_ref x2(bv.mk_extract(28, 0, x), m);
auto* nx = get_node(g, x1);
auto* ny = get_node(g, x2);
auto* nx = get_node(g, bv, x1);
auto* ny = get_node(g, bv, x2);
TRACE("bv", tout << "before merge\n" << g << "\n");
g.merge(nx, ny, nullptr);
TRACE("bv", tout << "before propagate\n" << g << "\n");