3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-06-24 14:53:40 +00:00

port updates to egraph from poly

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2024-02-03 12:48:58 -08:00
parent 24ffef8ac5
commit a5a819c291
11 changed files with 91 additions and 54 deletions

View file

@ -109,7 +109,7 @@ namespace euf {
m_shared_nodes.setx(n->get_id(), true, false); m_shared_nodes.setx(n->get_id(), true, false);
sort(monomial(m)); sort(monomial(m));
m_shared_todo.insert(m_shared.size()); m_shared_todo.insert(m_shared.size());
m_shared.push_back({ n, m, justification::axiom() }); m_shared.push_back({ n, m, justification::axiom(get_id()) });
push_undo(is_register_shared); push_undo(is_register_shared);
} }

View file

@ -137,7 +137,7 @@ namespace euf {
} }
}; };
unsigned m_fid = 0; theory_id m_fid = 0;
unsigned m_op = null_decl_kind; unsigned m_op = null_decl_kind;
func_decl* m_decl = nullptr; func_decl* m_decl = nullptr;
vector<eq> m_eqs; vector<eq> m_eqs;
@ -273,7 +273,7 @@ namespace euf {
~ac_plugin() override {} ~ac_plugin() override {}
unsigned get_id() const override { return m_fid; } theory_id get_id() const override { return m_fid; }
void register_node(enode* n) override; void register_node(enode* n) override;

View file

@ -35,7 +35,7 @@ namespace euf {
~arith_plugin() override {} ~arith_plugin() override {}
unsigned get_id() const override { return a.get_family_id(); } theory_id get_id() const override { return a.get_family_id(); }
void register_node(enode* n) override; void register_node(enode* n) override;

View file

@ -162,19 +162,26 @@ namespace euf {
if (!is_value(x)) if (!is_value(x))
return; return;
auto val_x = get_value(x);
enode* a, * b; enode* a, * b;
for (enode* p : enode_parents(x)) unsigned lo, hi;
if (is_concat(p, a, b) && is_value(a) && is_value(b) && !is_value(p)) for (enode* p : enode_parents(x)) {
if (is_concat(p, a, b) && is_value(a) && is_value(b))
push_merge(mk_concat(a->get_interpreted(), b->get_interpreted()), mk_value_concat(a, b)); push_merge(mk_concat(a->get_interpreted(), b->get_interpreted()), mk_value_concat(a, b));
if (is_extract(p, lo, hi)) {
auto val_p = mod2k(machine_div2k(val_x, lo), hi - lo + 1);
auto ix = x->get_interpreted();
auto ex = mk(bv.mk_extract(hi, lo, ix->get_expr()), 1, &ix);
push_merge(ex, mk_value(val_p, width(p)));
}
}
for (enode* sib : enode_class(x)) { for (enode* sib : enode_class(x)) {
if (is_concat(sib, a, b)) { if (is_concat(sib, a, b)) {
if (!is_value(a) || !is_value(b)) { auto val_a = machine_div2k(val_x, width(b));
auto val = get_value(x); auto val_b = mod2k(val_x, width(b));
auto val_a = machine_div2k(val, width(b)); push_merge(mk_concat(mk_value(val_a, width(a)), mk_value(val_b, width(b))), x->get_interpreted());
auto val_b = mod2k(val, width(b));
push_merge(mk_concat(mk_value(val_a, width(a)), mk_value(val_b, width(b))), x->get_interpreted());
}
} }
} }
} }
@ -198,7 +205,9 @@ namespace euf {
enode* arg_r = arg->get_root(); enode* arg_r = arg->get_root();
enode* n_r = n->get_root(); enode* n_r = n->get_root();
m_ensure_concat.reset();
auto ensure_concat = [&](unsigned lo, unsigned mid, unsigned hi) { auto ensure_concat = [&](unsigned lo, unsigned mid, unsigned hi) {
// verbose_stream() << lo << " " << mid << " " << hi << "\n";
TRACE("bv", tout << "ensure-concat " << lo << " " << mid << " " << hi << "\n"); TRACE("bv", tout << "ensure-concat " << lo << " " << mid << " " << hi << "\n");
unsigned lo_, hi_; unsigned lo_, hi_;
for (enode* p1 : enode_parents(n)) for (enode* p1 : enode_parents(n))
@ -212,14 +221,14 @@ namespace euf {
TRACE("bv", tout << "propagate-above " << g.bpp(b) << "\n"); TRACE("bv", tout << "propagate-above " << g.bpp(b) << "\n");
for (enode* sib : enode_class(b)) for (enode* sib : enode_class(b))
if (is_extract(sib, lo2, hi2) && sib->get_arg(0)->get_root() == arg_r && hi1 + 1 == lo2) if (is_extract(sib, lo2, hi2) && sib->get_arg(0)->get_root() == arg_r && hi1 + 1 == lo2)
ensure_concat(lo1, hi1, hi2); m_ensure_concat.push_back({lo1, hi1, hi2});
}; };
auto propagate_below = [&](enode* a) { auto propagate_below = [&](enode* a) {
TRACE("bv", tout << "propagate-below " << g.bpp(a) << "\n"); TRACE("bv", tout << "propagate-below " << g.bpp(a) << "\n");
for (enode* sib : enode_class(a)) for (enode* sib : enode_class(a))
if (is_extract(sib, lo2, hi2) && sib->get_arg(0)->get_root() == arg_r && hi2 + 1 == lo1) if (is_extract(sib, lo2, hi2) && sib->get_arg(0)->get_root() == arg_r && hi2 + 1 == lo1)
ensure_concat(lo2, hi2, hi1); m_ensure_concat.push_back({lo2, hi2, hi1});
}; };
for (enode* p : enode_parents(n)) { for (enode* p : enode_parents(n)) {
@ -230,6 +239,10 @@ namespace euf {
propagate_above(a); propagate_above(a);
} }
} }
for (auto [lo, mid, hi] : m_ensure_concat)
ensure_concat(lo, mid, hi);
} }
class bv_plugin::undo_split : public trail { class bv_plugin::undo_split : public trail {
@ -432,13 +445,14 @@ namespace euf {
delta += width(arg); delta += width(arg);
} }
} }
} for (auto p : euf::enode_parents(sib)) {
for (auto p : euf::enode_parents(n->get_root())) { if (bv.is_extract(p->get_expr(), lo, hi, e)) {
if (bv.is_extract(p->get_expr(), lo, hi, e)) { SASSERT(g.find(e)->get_root() == n->get_root());
SASSERT(g.find(e)->get_root() == n->get_root()); m_todo.push_back({ p, offset + lo });
m_todo.push_back({ p, offset + lo }); }
} }
} }
} }
clear_offsets(); clear_offsets();
} }
@ -462,15 +476,15 @@ namespace euf {
auto child = g.find(e); auto child = g.find(e);
m_todo.push_back({ child, offset + lo }); m_todo.push_back({ child, offset + lo });
} }
} for (auto p : euf::enode_parents(sib)) {
for (auto p : euf::enode_parents(n->get_root())) { if (bv.is_concat(p->get_expr())) {
if (bv.is_concat(p->get_expr())) { unsigned delta = 0;
unsigned delta = 0; for (unsigned j = p->num_args(); j-- > 0; ) {
for (unsigned j = p->num_args(); j-- > 0; ) { auto arg = p->get_arg(j);
auto arg = p->get_arg(j); if (arg->get_root() == n->get_root())
if (arg->get_root() == n->get_root()) m_todo.push_back({ p, offset + delta });
m_todo.push_back({ p, offset + delta }); delta += width(arg);
delta += width(arg); }
} }
} }
} }
@ -511,6 +525,9 @@ namespace euf {
m_offsets.reserve(n->get_root_id() + 1); m_offsets.reserve(n->get_root_id() + 1);
m_offsets[n->get_root_id()].reset(); m_offsets[n->get_root_id()].reset();
} }
for (auto const& off : m_offsets) {
SASSERT(off.empty());
}
m_jtodo.reset(); m_jtodo.reset();
return; return;
} }
@ -521,20 +538,27 @@ namespace euf {
just.push_back({ n, sib, j }); just.push_back({ n, sib, j });
for (unsigned j = sib->num_args(); j-- > 0; ) { for (unsigned j = sib->num_args(); j-- > 0; ) {
auto arg = sib->get_arg(j); auto arg = sib->get_arg(j);
m_jtodo.push_back({ arg, offset + delta, j2 }); m_jtodo.push_back({ arg, offs + delta, j2 });
delta += width(arg); delta += width(arg);
} }
} }
} for (auto p : euf::enode_parents(sib)) {
for (auto p : euf::enode_parents(n->get_root())) { if (bv.is_extract(p->get_expr(), lo, hi, e)) {
if (bv.is_extract(p->get_expr(), lo, hi, e)) { SASSERT(g.find(e)->get_root() == n->get_root());
SASSERT(g.find(e)->get_root() == n->get_root()); unsigned j2 = just.size();
unsigned j2 = just.size(); just.push_back({ g.find(e), n, j });
just.push_back({ g.find(e), n, j}); m_jtodo.push_back({ p, offs + lo, j2 });
m_jtodo.push_back({ p, offset + lo, j2}); }
} }
} }
} }
IF_VERBOSE(0,
g.display(verbose_stream());
verbose_stream() << g.bpp(a) << " offset " << offset << " " << g.bpp(b) << "\n";
for (auto const& [n, offset, j] : m_jtodo)
verbose_stream() << g.bpp(n) << " offset " << offset << " " << g.bpp(n->get_root()) << "\n";
);
UNREACHABLE(); UNREACHABLE();
} }

View file

@ -72,6 +72,7 @@ namespace euf {
bool unfold_width(enode* x, enode_vector& xs, enode* y, enode_vector& ys); bool unfold_width(enode* x, enode_vector& xs, enode* y, enode_vector& ys);
bool unfold_sub(enode* x, enode_vector& xs); bool unfold_sub(enode* x, enode_vector& xs);
void merge(enode_vector& xs, enode_vector& ys, justification j); void merge(enode_vector& xs, enode_vector& ys, justification j);
svector<std::tuple<unsigned, unsigned, unsigned>> m_ensure_concat;
void propagate_extract(enode* n); void propagate_extract(enode* n);
void propagate_values(enode* n); void propagate_values(enode* n);
@ -96,7 +97,7 @@ namespace euf {
~bv_plugin() override {} ~bv_plugin() override {}
unsigned get_id() const override { return bv.get_family_id(); } theory_id get_id() const override { return bv.get_family_id(); }
void register_node(enode* n) override; void register_node(enode* n) override;

View file

@ -135,7 +135,7 @@ namespace euf {
enode* prev = this; enode* prev = this;
justification js = m_justification; justification js = m_justification;
prev->m_target = nullptr; prev->m_target = nullptr;
prev->m_justification = justification::axiom(); prev->m_justification = justification::axiom(null_theory_id);
while (curr != nullptr) { while (curr != nullptr) {
enode* new_curr = curr->m_target; enode* new_curr = curr->m_target;
justification new_js = curr->m_justification; justification new_js = curr->m_justification;

View file

@ -36,10 +36,6 @@ namespace euf {
typedef std::pair<enode*,bool> enode_bool_pair; typedef std::pair<enode*,bool> enode_bool_pair;
typedef svector<enode_bool_pair> enode_bool_pair_vector; typedef svector<enode_bool_pair> enode_bool_pair_vector;
typedef id_var_list<> th_var_list; typedef id_var_list<> th_var_list;
typedef int theory_var;
typedef int theory_id;
const theory_var null_theory_var = -1;
const theory_id null_theory_id = -1;
class enode { class enode {
expr* m_expr = nullptr; expr* m_expr = nullptr;

View file

@ -28,6 +28,11 @@ namespace euf {
class enode; class enode;
typedef int theory_var;
typedef int theory_id;
const theory_var null_theory_var = -1;
const theory_id null_theory_id = -1;
class justification { class justification {
public: public:
typedef stacked_dependency_manager<justification> dependency_manager; typedef stacked_dependency_manager<justification> dependency_manager;
@ -42,6 +47,7 @@ namespace euf {
}; };
kind_t m_kind; kind_t m_kind;
union { union {
int m_theory_id;
bool m_comm; bool m_comm;
enode* m_n1; enode* m_n1;
}; };
@ -76,19 +82,27 @@ namespace euf {
m_n2(n2) m_n2(n2)
{} {}
public: justification(int theory_id):
justification():
m_kind(kind_t::axiom_t), m_kind(kind_t::axiom_t),
m_comm(false), m_theory_id(theory_id),
m_external(nullptr) m_external(nullptr)
{} {}
static justification axiom() { return justification(); } public:
justification():
m_kind(kind_t::axiom_t),
m_theory_id(null_theory_id),
m_external(nullptr)
{}
static justification axiom(int theory_id) { return justification(theory_id); }
static justification congruence(bool c, uint64_t ts) { return justification(c, ts); } static justification congruence(bool c, uint64_t ts) { return justification(c, ts); }
static justification external(void* ext) { return justification(ext); } static justification external(void* ext) { return justification(ext); }
static justification dependent(dependency* d) { return justification(d, 1); } static justification dependent(dependency* d) { return justification(d, 1); }
static justification equality(enode* a, enode* b) { return justification(a, b); } static justification equality(enode* a, enode* b) { return justification(a, b); }
bool is_axiom() const { return m_kind == kind_t::axiom_t; }
bool is_external() const { return m_kind == kind_t::external_t; } bool is_external() const { return m_kind == kind_t::external_t; }
bool is_congruence() const { return m_kind == kind_t::congruence_t; } bool is_congruence() const { return m_kind == kind_t::congruence_t; }
bool is_commutative() const { return m_comm; } bool is_commutative() const { return m_comm; }
@ -98,6 +112,7 @@ namespace euf {
enode* lhs() const { SASSERT(is_equality()); return m_n1; } enode* lhs() const { SASSERT(is_equality()); return m_n1; }
enode* rhs() const { SASSERT(is_equality()); return m_n2; } enode* rhs() const { SASSERT(is_equality()); return m_n2; }
uint64_t timestamp() const { SASSERT(is_congruence()); return m_timestamp; } uint64_t timestamp() const { SASSERT(is_congruence()); return m_timestamp; }
theory_id get_theory_id() const { SASSERT(is_axiom()); return m_theory_id; }
template <typename T> template <typename T>
T* ext() const { SASSERT(is_external()); return static_cast<T*>(m_external); } T* ext() const { SASSERT(is_external()); return static_cast<T*>(m_external); }
@ -106,7 +121,7 @@ namespace euf {
case kind_t::external_t: case kind_t::external_t:
return external(copy_justification(m_external)); return external(copy_justification(m_external));
case kind_t::axiom_t: case kind_t::axiom_t:
return axiom(); return axiom(m_theory_id);
case kind_t::congruence_t: case kind_t::congruence_t:
return congruence(m_comm, m_timestamp); return congruence(m_comm, m_timestamp);
case kind_t::dependent_t: case kind_t::dependent_t:
@ -114,7 +129,7 @@ namespace euf {
return dependent(m_dependency); return dependent(m_dependency);
default: default:
UNREACHABLE(); UNREACHABLE();
return axiom(); return axiom(-1);
} }
} }

View file

@ -26,12 +26,13 @@ namespace euf {
} }
void plugin::push_merge(enode* a, enode* b, justification j) { void plugin::push_merge(enode* a, enode* b, justification j) {
TRACE("euf", tout << "push-merge " << g.bpp(a) << " == " << g.bpp(b) << " " << j << "\n");
g.push_merge(a, b, j); g.push_merge(a, b, j);
} }
void plugin::push_merge(enode* a, enode* b) { void plugin::push_merge(enode* a, enode* b) {
TRACE("plugin", tout << g.bpp(a) << " == " << g.bpp(b) << "\n"); TRACE("plugin", tout << g.bpp(a) << " == " << g.bpp(b) << "\n");
g.push_merge(a, b, justification::axiom()); g.push_merge(a, b, justification::axiom(get_id()));
} }
enode* plugin::mk(expr* e, unsigned n, enode* const* args) { enode* plugin::mk(expr* e, unsigned n, enode* const* args) {

View file

@ -40,7 +40,7 @@ namespace euf {
virtual ~plugin() {} virtual ~plugin() {}
virtual unsigned get_id() const = 0; virtual theory_id get_id() const = 0;
virtual void register_node(enode* n) = 0; virtual void register_node(enode* n) = 0;

View file

@ -37,7 +37,7 @@ namespace euf {
~specrel_plugin() override {} ~specrel_plugin() override {}
unsigned get_id() const override { return sp.get_family_id(); } theory_id get_id() const override { return sp.get_family_id(); }
void register_node(enode* n) override; void register_node(enode* n) override;