3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-06-15 18:36:16 +00:00

making var_eqs into a template

Signed-off-by: Lev Nachmanson <levnach@hotmail.com>
This commit is contained in:
Lev Nachmanson 2019-05-13 15:03:41 -07:00
parent 50d3e67e61
commit 94e3078920
6 changed files with 171 additions and 220 deletions

View file

@ -39,7 +39,6 @@ z3_add_component(lp
square_dense_submatrix.cpp square_dense_submatrix.cpp
square_sparse_matrix.cpp square_sparse_matrix.cpp
static_matrix.cpp static_matrix.cpp
var_eqs.cpp
COMPONENT_DEPENDENCIES COMPONENT_DEPENDENCIES
util util
polynomial polynomial

View file

@ -399,5 +399,4 @@ std::ostream& emonomials::display(std::ostream& out) const {
return out; return out;
} }
} }

View file

@ -29,7 +29,7 @@ namespace nla {
class core; class core;
class emonomials : public var_eqs_merge_handler { class emonomials {
/** /**
\brief singly-lined cyclic list of monomial indices where variable occurs. \brief singly-lined cyclic list of monomial indices where variable occurs.
Each variable points to the head and tail of the cyclic list. Each variable points to the head and tail of the cyclic list.
@ -75,7 +75,7 @@ class emonomials : public var_eqs_merge_handler {
}; };
mutable svector<lpvar> m_find_key; // the key used when looking for a monomial with the specific variables mutable svector<lpvar> m_find_key; // the key used when looking for a monomial with the specific variables
var_eqs& m_ve; var_eqs<emonomials>& m_ve;
mutable vector<monomial> m_monomials; // set of monomials mutable vector<monomial> m_monomials; // set of monomials
mutable unsigned_vector m_var2index; // var_mIndex -> mIndex mutable unsigned_vector m_var2index; // var_mIndex -> mIndex
unsigned_vector m_lim; // backtracking point unsigned_vector m_lim; // backtracking point
@ -110,7 +110,7 @@ public:
push and pop on emonomials calls push/pop on var_eqs, so no push and pop on emonomials calls push/pop on var_eqs, so no
other calls to push/pop to the var_eqs should take place. other calls to push/pop to the var_eqs should take place.
*/ */
emonomials(var_eqs& ve): emonomials(var_eqs<emonomials>& ve):
m_ve(ve), m_ve(ve),
m_visited(0), m_visited(0),
m_cg_hash(*this), m_cg_hash(*this),
@ -288,11 +288,11 @@ public:
these are merge event handlers to interect the union-find handlers. these are merge event handlers to interect the union-find handlers.
r2 becomes the new root. r2 is the root of v2, r1 is the old root of v1 r2 becomes the new root. r2 is the root of v2, r1 is the old root of v1
*/ */
void merge_eh(signed_var r2, signed_var r1, signed_var v2, signed_var v1) override; void merge_eh(signed_var r2, signed_var r1, signed_var v2, signed_var v1);
void after_merge_eh(signed_var r2, signed_var r1, signed_var v2, signed_var v1) override; void after_merge_eh(signed_var r2, signed_var r1, signed_var v2, signed_var v1);
void unmerge_eh(signed_var r2, signed_var r1) override; void unmerge_eh(signed_var r2, signed_var r1);
bool is_monomial_var(lpvar v) const { return m_var2index.get(v, UINT_MAX) != UINT_MAX; } bool is_monomial_var(lpvar v) const { return m_var2index.get(v, UINT_MAX) != UINT_MAX; }
}; };

View file

@ -76,7 +76,7 @@ public:
class core { class core {
public: public:
var_eqs m_evars; var_eqs<emonomials> m_evars;
lp::lar_solver& m_lar_solver; lp::lar_solver& m_lar_solver;
vector<lemma> * m_lemma_vec; vector<lemma> * m_lemma_vec;
svector<lpvar> m_to_refine; svector<lpvar> m_to_refine;

View file

@ -1,189 +0,0 @@
/*++
Copyright (c) 2017 Microsoft Corporation
Module Name:
<name>
Abstract:
<abstract>
Author:
Nikolaj Bjorner (nbjorner)
Lev Nachmanson (levnach)
Revision History:
--*/
#include "util/lp/var_eqs.h"
namespace nla {
var_eqs::var_eqs(): m_merge_handler(nullptr), m_uf(*this), m_stack(*this) {}
void var_eqs::push() {
m_trail_lim.push_back(m_trail.size());
m_stack.push_scope();
}
void var_eqs::pop(unsigned n) {
unsigned old_sz = m_trail_lim[m_trail_lim.size() - n];
for (unsigned i = m_trail.size(); i-- > old_sz; ) {
auto const& sv = m_trail[i];
m_eqs[sv.first.index()].pop_back();
m_eqs[sv.second.index()].pop_back();
m_eqs[(~sv.first).index()].pop_back();
m_eqs[(~sv.second).index()].pop_back();
}
m_trail_lim.shrink(m_trail_lim.size() - n);
m_trail.shrink(old_sz);
m_stack.pop_scope(n);
}
void var_eqs::merge(signed_var v1, signed_var v2, eq_justification const& j) {
unsigned max_i = std::max(v1.index(), v2.index()) + 2;
m_eqs.reserve(max_i);
while (m_uf.get_num_vars() <= max_i) m_uf.mk_var();
m_trail.push_back(std::make_pair(v1, v2));
m_uf.merge(v1.index(), v2.index());
m_uf.merge((~v1).index(), (~v2).index());
m_eqs[v1.index()].push_back(eq_edge(v2, j));
m_eqs[v2.index()].push_back(eq_edge(v1, j));
m_eqs[(~v1).index()].push_back(eq_edge(~v2, j));
m_eqs[(~v2).index()].push_back(eq_edge(~v1, j));
}
signed_var var_eqs::find(signed_var v) const {
if (v.index() >= m_uf.get_num_vars()) {
return v;
}
unsigned idx = m_uf.find(v.index());
return signed_var(idx);
}
void var_eqs::explain_dfs(signed_var v1, signed_var v2, lp::explanation& e) const {
SASSERT(find(v1) == find(v2));
if (v1 == v2) {
return;
}
m_todo.push_back(var_frame(v1, 0));
m_justtrail.reset();
m_marked.reserve(m_eqs.size(), false);
SASSERT(m_marked_trail.empty());
m_marked[v1.index()] = true;
m_marked_trail.push_back(v1.index());
while (true) {
SASSERT(!m_todo.empty());
var_frame& f = m_todo.back();
signed_var v = f.m_var;
if (v == v2) {
break;
}
auto const& next = m_eqs[v.index()];
bool seen_all = true;
unsigned sz = next.size();
for (unsigned i = f.m_index; seen_all && i < sz; ++i) {
eq_edge const& jv = next[i];
signed_var v3 = jv.m_var;
if (!m_marked[v3.index()]) {
seen_all = false;
f.m_index = i + 1;
m_todo.push_back(var_frame(v3, 0));
m_justtrail.push_back(jv.m_just);
m_marked_trail.push_back(v3.index());
m_marked[v3.index()] = true;
}
}
if (seen_all) {
m_todo.pop_back();
m_justtrail.pop_back();
}
}
for (eq_justification const& j : m_justtrail) {
j.explain(e);
}
m_stats.m_num_explains += m_justtrail.size();
m_stats.m_num_explain_calls++;
m_todo.reset();
m_justtrail.reset();
for (unsigned idx : m_marked_trail) {
m_marked[idx] = false;
}
m_marked_trail.reset();
// IF_VERBOSE(2, verbose_stream() << (double)m_stats.m_num_explains / m_stats.m_num_explain_calls << "\n");
}
void var_eqs::explain_bfs(signed_var v1, signed_var v2, lp::explanation& e) const {
SASSERT(find(v1) == find(v2));
if (v1 == v2) {
return;
}
m_todo.push_back(var_frame(v1, 0));
m_justtrail.push_back(eq_justification({}));
m_marked.reserve(m_eqs.size(), false);
SASSERT(m_marked_trail.empty());
m_marked[v1.index()] = true;
m_marked_trail.push_back(v1.index());
unsigned head = 0;
for (; ; ++head) {
var_frame& f = m_todo[head];
signed_var v = f.m_var;
if (v == v2) {
break;
}
auto const& next = m_eqs[v.index()];
unsigned sz = next.size();
for (unsigned i = sz; i-- > 0; ) {
eq_edge const& jv = next[i];
signed_var v3 = jv.m_var;
if (!m_marked[v3.index()]) {
m_todo.push_back(var_frame(v3, head));
m_justtrail.push_back(jv.m_just);
m_marked_trail.push_back(v3.index());
m_marked[v3.index()] = true;
}
}
}
while (head != 0) {
m_justtrail[head].explain(e);
head = m_todo[head].m_index;
++m_stats.m_num_explains;
}
++m_stats.m_num_explain_calls;
m_todo.reset();
m_justtrail.reset();
for (unsigned idx : m_marked_trail) {
m_marked[idx] = false;
}
m_marked_trail.reset();
// IF_VERBOSE(2, verbose_stream() << (double)m_stats.m_num_explains / m_stats.m_num_explain_calls << "\n");
}
std::ostream& var_eqs::display(std::ostream& out) const {
m_uf.display(out);
unsigned idx = 0;
for (auto const& edges : m_eqs) {
if (!edges.empty()) {
auto v = signed_var(idx);
out << v << " root: " << find(v) << " : ";
for (auto const& jv : edges) {
out << jv.m_var << " ";
}
out << "\n";
}
++idx;
}
return out;
}
}

View file

@ -45,13 +45,7 @@ public:
} }
}; };
class var_eqs_merge_handler { template <typename T>
public:
virtual void merge_eh(signed_var r2, signed_var r1, signed_var v2, signed_var v1) = 0;
virtual void after_merge_eh(signed_var r2, signed_var r1, signed_var v2, signed_var v1) = 0;
virtual void unmerge_eh(signed_var r2, signed_var r1) = 0;
};
class var_eqs { class var_eqs {
struct eq_edge { struct eq_edge {
signed_var m_var; signed_var m_var;
@ -70,7 +64,7 @@ class var_eqs {
stats() { memset(this, 0, sizeof(*this)); } stats() { memset(this, 0, sizeof(*this)); }
}; };
var_eqs_merge_handler* m_merge_handler; T* m_merge_handler;
union_find<var_eqs> m_uf; union_find<var_eqs> m_uf;
svector<std::pair<signed_var, signed_var>> m_trail; svector<std::pair<signed_var, signed_var>> m_trail;
unsigned_vector m_trail_lim; unsigned_vector m_trail_lim;
@ -83,30 +77,61 @@ class var_eqs {
mutable svector<eq_justification> m_justtrail; mutable svector<eq_justification> m_justtrail;
mutable stats m_stats; mutable stats m_stats;
public: public:
var_eqs(); var_eqs(): m_merge_handler(nullptr), m_uf(*this), m_stack(*this) {}
/** /**
\brief push a scope \brief push a scope */
*/ void push() {
void push(); m_trail_lim.push_back(m_trail.size());
m_stack.push_scope();
}
/** /**
\brief pop n scopes \brief pop n scopes
*/ */
void pop(unsigned n); void pop(unsigned n) {
unsigned old_sz = m_trail_lim[m_trail_lim.size() - n];
for (unsigned i = m_trail.size(); i-- > old_sz; ) {
auto const& sv = m_trail[i];
m_eqs[sv.first.index()].pop_back();
m_eqs[sv.second.index()].pop_back();
m_eqs[(~sv.first).index()].pop_back();
m_eqs[(~sv.second).index()].pop_back();
}
m_trail_lim.shrink(m_trail_lim.size() - n);
m_trail.shrink(old_sz);
m_stack.pop_scope(n);
}
/** /**
\brief merge equivalence classes for v1, v2 with justification j \brief merge equivalence classes for v1, v2 with justification j
*/ */
void merge(signed_var v1, signed_var v2, eq_justification const& j); void merge(signed_var v1, signed_var v2, eq_justification const& j) {
unsigned max_i = std::max(v1.index(), v2.index()) + 2;
m_eqs.reserve(max_i);
while (m_uf.get_num_vars() <= max_i) m_uf.mk_var();
m_trail.push_back(std::make_pair(v1, v2));
m_uf.merge(v1.index(), v2.index());
m_uf.merge((~v1).index(), (~v2).index());
m_eqs[v1.index()].push_back(eq_edge(v2, j));
m_eqs[v2.index()].push_back(eq_edge(v1, j));
m_eqs[(~v1).index()].push_back(eq_edge(~v2, j));
m_eqs[(~v2).index()].push_back(eq_edge(~v1, j));
}
void merge_plus(lpvar v1, lpvar v2, eq_justification const& j) { merge(signed_var(v1, false), signed_var(v2, false), j); } void merge_plus(lpvar v1, lpvar v2, eq_justification const& j) { merge(signed_var(v1, false), signed_var(v2, false), j); }
void merge_minus(lpvar v1, lpvar v2, eq_justification const& j) { merge(signed_var(v1, false), signed_var(v2, true), j); } void merge_minus(lpvar v1, lpvar v2, eq_justification const& j) { merge(signed_var(v1, false), signed_var(v2, true), j); }
/** /**
\brief find equivalence class representative for v \brief find equivalence class representative for v
*/ */
signed_var find(signed_var v) const; signed_var find(signed_var v) const {
if (v.index() >= m_uf.get_num_vars()) {
return v;
}
unsigned idx = m_uf.find(v.index());
return signed_var(idx);
}
inline signed_var find(lpvar j) const { inline signed_var find(lpvar j) const {
return find(signed_var(j, false)); return find(signed_var(j, false));
@ -132,8 +157,109 @@ public:
\brief Returns eq_justifications for \brief Returns eq_justifications for
\pre find(v1) == find(v2) \pre find(v1) == find(v2)
*/ */
void explain_dfs(signed_var v1, signed_var v2, lp::explanation& e) const; void explain_dfs(signed_var v1, signed_var v2, lp::explanation& e) const {
void explain_bfs(signed_var v1, signed_var v2, lp::explanation& e) const; SASSERT(find(v1) == find(v2));
if (v1 == v2) {
return;
}
m_todo.push_back(var_frame(v1, 0));
m_justtrail.reset();
m_marked.reserve(m_eqs.size(), false);
SASSERT(m_marked_trail.empty());
m_marked[v1.index()] = true;
m_marked_trail.push_back(v1.index());
while (true) {
SASSERT(!m_todo.empty());
var_frame& f = m_todo.back();
signed_var v = f.m_var;
if (v == v2) {
break;
}
auto const& next = m_eqs[v.index()];
bool seen_all = true;
unsigned sz = next.size();
for (unsigned i = f.m_index; seen_all && i < sz; ++i) {
eq_edge const& jv = next[i];
signed_var v3 = jv.m_var;
if (!m_marked[v3.index()]) {
seen_all = false;
f.m_index = i + 1;
m_todo.push_back(var_frame(v3, 0));
m_justtrail.push_back(jv.m_just);
m_marked_trail.push_back(v3.index());
m_marked[v3.index()] = true;
}
}
if (seen_all) {
m_todo.pop_back();
m_justtrail.pop_back();
}
}
for (eq_justification const& j : m_justtrail) {
j.explain(e);
}
m_stats.m_num_explains += m_justtrail.size();
m_stats.m_num_explain_calls++;
m_todo.reset();
m_justtrail.reset();
for (unsigned idx : m_marked_trail) {
m_marked[idx] = false;
}
m_marked_trail.reset();
// IF_VERBOSE(2, verbose_stream() << (double)m_stats.m_num_explains / m_stats.m_num_explain_calls << "\n");
}
void explain_bfs(signed_var v1, signed_var v2, lp::explanation& e) const {
SASSERT(find(v1) == find(v2));
if (v1 == v2) {
return;
}
m_todo.push_back(var_frame(v1, 0));
m_justtrail.push_back(eq_justification({}));
m_marked.reserve(m_eqs.size(), false);
SASSERT(m_marked_trail.empty());
m_marked[v1.index()] = true;
m_marked_trail.push_back(v1.index());
unsigned head = 0;
for (; ; ++head) {
var_frame& f = m_todo[head];
signed_var v = f.m_var;
if (v == v2) {
break;
}
auto const& next = m_eqs[v.index()];
unsigned sz = next.size();
for (unsigned i = sz; i-- > 0; ) {
eq_edge const& jv = next[i];
signed_var v3 = jv.m_var;
if (!m_marked[v3.index()]) {
m_todo.push_back(var_frame(v3, head));
m_justtrail.push_back(jv.m_just);
m_marked_trail.push_back(v3.index());
m_marked[v3.index()] = true;
}
}
}
while (head != 0) {
m_justtrail[head].explain(e);
head = m_todo[head].m_index;
++m_stats.m_num_explains;
}
++m_stats.m_num_explain_calls;
m_todo.reset();
m_justtrail.reset();
for (unsigned idx : m_marked_trail) {
m_marked[idx] = false;
}
m_marked_trail.reset();
// IF_VERBOSE(2, verbose_stream() << (double)m_stats.m_num_explains / m_stats.m_num_explain_calls << "\n");
}
inline void explain(signed_var v1, signed_var v2, lp::explanation& e) const { inline void explain(signed_var v1, signed_var v2, lp::explanation& e) const {
explain_bfs(v1, v2, e); explain_bfs(v1, v2, e);
@ -176,10 +302,25 @@ public:
eq_class equiv_class(lpvar v) { return equiv_class(signed_var(v, false)); } eq_class equiv_class(lpvar v) { return equiv_class(signed_var(v, false)); }
std::ostream& display(std::ostream& out) const; std::ostream& display(std::ostream& out) const {
m_uf.display(out);
unsigned idx = 0;
for (auto const& edges : m_eqs) {
if (!edges.empty()) {
auto v = signed_var(idx);
out << v << " root: " << find(v) << " : ";
for (auto const& jv : edges) {
out << jv.m_var << " ";
}
out << "\n";
}
++idx;
}
return out;
}
// union find event handlers // union find event handlers
void set_merge_handler(var_eqs_merge_handler* mh) { m_merge_handler = mh; } void set_merge_handler(T* mh) { m_merge_handler = mh; }
// this method is required by union_find // this method is required by union_find
trail_stack<var_eqs> & get_trail_stack() { return m_stack; } trail_stack<var_eqs> & get_trail_stack() { return m_stack; }
@ -203,6 +344,7 @@ public:
} }
}; // end of var_eqs }; // end of var_eqs
inline std::ostream& operator<<(var_eqs const& ve, std::ostream& out) { return ve.display(out); } template <typename T>
std::ostream& operator<<(var_eqs<T> const& ve, std::ostream& out) { return ve.display(out); }
} }