3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-15 13:28:47 +00:00

rewrite horner scheme on top of nex_expr as a pointer

Signed-off-by: Lev Nachmanson <levnach@hotmail.com>
This commit is contained in:
Lev Nachmanson 2019-08-16 18:10:12 -07:00
parent 9fbd0da931
commit a7449494a9
3 changed files with 162 additions and 75 deletions

View file

@ -35,7 +35,7 @@ class cross_nested {
}; };
// fields // fields
nex_sum * m_e; nex * m_e;
std::function<bool (const nex*)> m_call_on_result; std::function<bool (const nex*)> m_call_on_result;
std::function<bool (unsigned)> m_var_is_fixed; std::function<bool (unsigned)> m_var_is_fixed;
bool m_done; bool m_done;
@ -43,6 +43,7 @@ class cross_nested {
std::unordered_map<lpvar, unsigned> m_powers; std::unordered_map<lpvar, unsigned> m_powers;
vector<nex*> m_allocated; vector<nex*> m_allocated;
vector<nex*> m_b_vec; vector<nex*> m_b_vec;
vector<nex*> m_b_split_vec;
public: public:
cross_nested(std::function<bool (const nex*)> call_on_result, cross_nested(std::function<bool (const nex*)> call_on_result,
std::function<bool (unsigned)> var_is_fixed): std::function<bool (unsigned)> var_is_fixed):
@ -51,16 +52,16 @@ public:
m_done(false) m_done(false)
{} {}
void run(nex_sum *e) { void run(nex *e) {
m_e = e; m_e = e;
vector<nex_sum*> front; vector<nex**> front;
explore_expr_on_front_elem(m_e, front); explore_expr_on_front_elem(m_e, front);
} }
static nex_sum* pop_back(vector<nex_sum*>& front) { static nex** pop_front(vector<nex**>& front) {
nex_sum* c = front.back(); nex** c = front.back();
TRACE("nla_cn", tout << *c << "\n";); TRACE("nla_cn", tout << **c << "\n";);
front.pop_back(); front.pop_back();
return c; return c;
} }
@ -70,6 +71,14 @@ public:
m_allocated.push_back(r); m_allocated.push_back(r);
return r; return r;
} }
template <typename T>
void add_children(T) { }
template <typename T, typename K, typename ...Args>
void add_children(T r, K e, Args ... es) {
r->add_child(e);
add_children(r, es ...);
}
nex_sum* mk_sum(const vector<nex*>& v) { nex_sum* mk_sum(const vector<nex*>& v) {
auto r = new nex_sum(); auto r = new nex_sum();
@ -78,40 +87,41 @@ public:
return r; return r;
} }
nex_sum* mk_sum(nex *a, nex* b) { nex_mul* mk_mul(const vector<nex*>& v) {
auto r = new nex_sum(); auto r = new nex_mul();
m_allocated.push_back(r); m_allocated.push_back(r);
r->children().push_back(a); r->children() = v;
r->children().push_back(b);
return r; return r;
} }
template <typename K, typename...Args>
nex_sum* mk_sum(K e, Args... es) {
auto r = new nex_sum();
m_allocated.push_back(r);
r->add_child(e);
add_children(r, es...);
return r;
}
nex_var* mk_var(lpvar j) { nex_var* mk_var(lpvar j) {
auto r = new nex_var(j); auto r = new nex_var(j);
m_allocated.push_back(r); m_allocated.push_back(r);
return r; return r;
} }
nex_mul* mk_mul() { nex_mul* mk_mul() {
auto r = new nex_mul(); auto r = new nex_mul();
m_allocated.push_back(r); m_allocated.push_back(r);
return r; return r;
} }
nex_mul* mk_mul(nex * a, nex * b) { template <typename K, typename...Args>
nex_mul* mk_mul(K e, Args... es) {
auto r = new nex_mul(); auto r = new nex_mul();
m_allocated.push_back(r); m_allocated.push_back(r);
r->add_child(a); r->add_child(b); add_children(r, e, es...);
return r; return r;
} }
nex_mul* mk_mul(nex * a, nex * b, nex *c) {
auto r = new nex_mul();
m_allocated.push_back(r);
r->add_child(a); r->add_child(b); r->add_child(c);
return r;
}
nex_scalar* mk_scalar(const rational& v) { nex_scalar* mk_scalar(const rational& v) {
auto r = new nex_scalar(v); auto r = new nex_scalar(v);
m_allocated.push_back(r); m_allocated.push_back(r);
@ -120,8 +130,32 @@ public:
nex * mk_div(const nex* a, lpvar j) { nex * mk_div(const nex* a, lpvar j) {
SASSERT(false); TRACE("nla_cn_details", tout << "a=" << *a << ", v" << j << "\n";);
return nullptr; SASSERT((a->is_mul() && a->contains(j)) || (a->is_var() && to_var(a)->var() == j));
if (a->is_var())
return mk_scalar(rational(1));
m_b_vec.clear();
bool seenj = false;
for (nex* c : to_mul(a)->children()) {
if (!seenj) {
if (c->contains(j)) {
if (!c->is_var())
m_b_vec.push_back(mk_div(c, j));
seenj = true;
continue;
}
}
m_b_vec.push_back(c);
}
if (m_b_vec.size() > 1) {
return mk_mul(m_b_vec);
}
if (m_b_vec.size() == 1) {
return m_b_vec[0];
}
SASSERT(m_b_vec.size() == 0);
return mk_scalar(rational(1));
} }
nex * mk_div(const nex* a, const nex* b) { nex * mk_div(const nex* a, const nex* b) {
@ -218,14 +252,14 @@ public:
return false; return false;
} }
bool proceed_with_common_factor(nex*& c, vector<nex_sum*>& front, const vector<std::pair<lpvar, occ>> & occurences) { bool proceed_with_common_factor(nex*& c, vector<nex**>& front, const vector<std::pair<lpvar, occ>> & occurences) {
TRACE("nla_cn", tout << "c=" << *c << "\n";); TRACE("nla_cn", tout << "c=" << *c << "\n";);
nex* f = extract_common_factor(c, occurences); nex* f = extract_common_factor(c, occurences);
if (f == nullptr) if (f == nullptr)
return false; return false;
nex_sum* c_over_f = to_sum(mk_div(c, f)); nex* c_over_f = mk_div(c, f);
c_over_f->simplify(); to_sum(c_over_f)->simplify();
c = mk_mul(f, c_over_f); c = mk_mul(f, c_over_f);
TRACE("nla_cn", tout << "common factor=" << *f << ", c=" << *c << "\ne = " << *m_e << "\n";); TRACE("nla_cn", tout << "common factor=" << *f << ", c=" << *c << "\ne = " << *m_e << "\n";);
@ -233,28 +267,28 @@ public:
return true; return true;
} }
static void push(vector<nex_sum*>& front, nex_sum* e) { static void push(vector<nex**>& front, nex** e) {
TRACE("nla_cn", tout << *e << "\n";); TRACE("nla_cn", tout << **e << "\n";);
front.push_back(e); front.push_back(e);
} }
static vector<nex_sum*> copy_front(const vector<nex_sum*>& front) { static vector<nex*> copy_front(const vector<nex**>& front) {
vector<nex_sum*> v; vector<nex*> v;
for (nex_sum* n: front) for (nex** n: front)
v.push_back(n); v.push_back(*n);
return v; return v;
} }
static void restore_front(const vector<nex_sum*> &copy, vector<nex_sum*>& front) { static void restore_front(const vector<nex*> &copy, vector<nex**>& front) {
SASSERT(copy.size() == front.size()); SASSERT(copy.size() == front.size());
for (unsigned i = 0; i < front.size(); i++) for (unsigned i = 0; i < front.size(); i++)
front[i] = copy[i]; *(front[i]) = copy[i];
} }
void explore_expr_on_front_elem_occs(nex* c, vector<nex_sum*>& front, const vector<std::pair<lpvar, occ>> & occurences) { void explore_expr_on_front_elem_occs(nex* &c, vector<nex**>& front, const vector<std::pair<lpvar, occ>> & occurences) {
if (proceed_with_common_factor(c, front, occurences)) if (proceed_with_common_factor(c, front, occurences))
return; return;
TRACE("nla_cn", tout << "save c=" << *c << "; front:"; print_vector_of_ptrs(front, tout) << "\n";); TRACE("nla_cn", tout << "save c=" << *c << "; front:"; print_front(front, tout) << "\n";);
nex* copy_of_c = c; nex* copy_of_c = c;
auto copy_of_front = copy_front(front); auto copy_of_front = copy_front(front);
for(auto& p : occurences) { for(auto& p : occurences) {
@ -269,11 +303,12 @@ public:
explore_of_expr_on_sum_and_var(c, j, front); explore_of_expr_on_sum_and_var(c, j, front);
if (m_done) if (m_done)
return; return;
TRACE("nla_cn", tout << "before restore c=" << *c << ", m_e=" << *m_e << "\n";);
c = copy_of_c; c = copy_of_c;
TRACE("nla_cn", tout << "restore c=" << *c << ", m_e=" << m_e << "\n";); TRACE("nla_cn", tout << "after restore c=" << *c << ", m_e=" << *m_e << "\n";);
restore_front(copy_of_front, front); restore_front(copy_of_front, front);
TRACE("nla_cn", tout << "restore c=" << *c << "\n";); TRACE("nla_cn", tout << "restore c=" << *c << "\n";);
TRACE("nla_cn", tout << "m_e=" << m_e << "\n";); TRACE("nla_cn", tout << "m_e=" << *m_e << "\n";);
} }
} }
@ -288,18 +323,18 @@ public:
return out; return out;
} }
void explore_expr_on_front_elem(nex_sum* c, vector<nex_sum*>& front) { void explore_expr_on_front_elem(nex*& c, vector<nex**>& front) {
auto occurences = get_mult_occurences(c); auto occurences = get_mult_occurences(to_sum(c));
TRACE("nla_cn", tout << "m_e=" << m_e << "\nc=" << *c << ", c occurences="; TRACE("nla_cn", tout << "m_e=" << *m_e << "\nc=" << *c << ", c occurences=";
dump_occurences(tout, occurences) << "; front:"; print_vector_of_ptrs(front, tout) << "\n";); dump_occurences(tout, occurences) << "; front:"; print_front(front, tout) << "\n";);
if (occurences.empty()) { if (occurences.empty()) {
if(front.empty()) { if(front.empty()) {
TRACE("nla_cn", tout << "got the cn form: =" << m_e << "\n";); TRACE("nla_cn", tout << "got the cn form: =" << *m_e << "\n";);
m_done = m_call_on_result(m_e); m_done = m_call_on_result(m_e);
} else { } else {
auto c = pop_back(front); nex* f = *pop_front(front);
explore_expr_on_front_elem(c, front); explore_expr_on_front_elem(f, front);
} }
} else { } else {
explore_expr_on_front_elem_occs(c, front, occurences); explore_expr_on_front_elem_occs(c, front, occurences);
@ -311,15 +346,23 @@ public:
return s.str(); return s.str();
// return (char)('a'+j); // return (char)('a'+j);
} }
// e is the global expression, c is the sub expressiond which is going to changed from sum to the cross nested form
void explore_of_expr_on_sum_and_var(nex* & c, lpvar j, vector<nex_sum*> front) { std::ostream& print_front(const vector<nex**>& front, std::ostream& out) const {
TRACE("nla_cn", tout << "m_e=" << m_e << "\nc=" << *c << "\nj = " << ch(j) << "\nfront="; print_vector_of_ptrs(front, tout) << "\n";); for (auto e : front) {
out << **e << "\n";
}
return out;
}
// c is the sub expressiond which is going to be changed from sum to the cross nested form
// front will be explored more
void explore_of_expr_on_sum_and_var(nex*& c, lpvar j, vector<nex**> front) {
TRACE("nla_cn", tout << "m_e=" << *m_e << "\nc=" << *c << "\nj = " << ch(j) << "\nfront="; print_front(front, tout) << "\n";);
if (!split_with_var(c, j, front)) if (!split_with_var(c, j, front))
return; return;
TRACE("nla_cn", tout << "after split c=" << *c << "\nfront="; print_vector_of_ptrs(front, tout) << "\n";); TRACE("nla_cn", tout << "after split c=" << *c << "\nfront="; print_front(front, tout) << "\n";);
SASSERT(front.size()); SASSERT(front.size());
auto n = pop_back(front); TRACE("nla_cn", tout << "n=" << *n <<"\n";); auto n = pop_front(front);
explore_expr_on_front_elem(n, front); explore_expr_on_front_elem(*n, front);
} }
void add_var_occs(lpvar j) { void add_var_occs(lpvar j) {
@ -378,7 +421,7 @@ public:
} }
} }
remove_singular_occurences(); remove_singular_occurences();
TRACE("nla_cn_details", tout << "e=" << e << "\noccs="; dump_occurences(tout, m_occurences_map) << "\n";); TRACE("nla_cn_details", tout << "e=" << *e << "\noccs="; dump_occurences(tout, m_occurences_map) << "\n";);
vector<std::pair<lpvar, occ>> ret; vector<std::pair<lpvar, occ>> ret;
for (auto & p : m_occurences_map) for (auto & p : m_occurences_map)
ret.push_back(p); ret.push_back(p);
@ -405,54 +448,57 @@ public:
void pre_split(nex_sum * e, lpvar j, nex_sum* & a, nex* & b) { void pre_split(nex_sum * e, lpvar j, nex_sum* & a, nex* & b) {
a = mk_sum(); a = mk_sum();
m_b_vec.clear(); m_b_split_vec.clear();
for (nex * ce: e->children()) { for (nex * ce: e->children()) {
if (is_divisible_by_var(ce, j)) { if (is_divisible_by_var(ce, j)) {
a->add_child(mk_div(ce , j)); a->add_child(mk_div(ce , j));
} else { } else {
m_b_vec.push_back(ce); m_b_split_vec.push_back(ce);
TRACE("nla_cn_details", tout << "ce = " << *ce << "\n";);
} }
} }
TRACE("nla_cn_details", tout << "a = " << *a << "\n";); TRACE("nla_cn_details", tout << "a = " << *a << "\n";);
SASSERT(a->children().size() >= 2 && m_b_vec.size()); SASSERT(a->children().size() >= 2 && m_b_split_vec.size());
a->simplify(); a->simplify();
if (m_b_vec.size() == 1) { if (m_b_split_vec.size() == 1) {
b = m_b_vec[0]; b = m_b_split_vec[0];
TRACE("nla_cn_details", tout << "b = " << *b << "\n";);
} else { } else {
SASSERT(m_b_vec.size() > 1); SASSERT(m_b_split_vec.size() > 1);
b = mk_sum(m_b_vec); b = mk_sum(m_b_split_vec);
} TRACE("nla_cn_details", tout << "b = " << *b << "\n";);
}
} }
void update_front_with_split_with_non_empty_b(nex* &e, lpvar j, vector<nex_sum*> & front, nex_sum* a, nex* b) { void update_front_with_split_with_non_empty_b(nex* &e, lpvar j, vector<nex**> & front, nex* a, nex* b) {
SASSERT(a->is_sum()); SASSERT(a->is_sum());
TRACE("nla_cn_details", tout << "b = " << b << "\n";); TRACE("nla_cn_details", tout << "b = " << *b << "\n";);
e = mk_sum(mk_mul(mk_var(j), a), b); // e = j*a + b e = mk_sum(mk_mul(mk_var(j), a), b); // e = j*a + b
push(front, a); // pushing 'a' nex **ptr_to_a = &(to_mul(to_sum(e)->children()[0]))->children()[1];
TRACE("nla_cn", tout << "push to front " << *a << "\n";); push(front, ptr_to_a);
if (b->is_sum()) { if (b->is_sum()) {
push(front, to_sum(b)); nex **ptr_to_a = &(to_sum(e)->children()[1]);
TRACE("nla_cn", tout << "push to front " << *b << "\n";); push(front, ptr_to_a);
} }
} }
void update_front_with_split(nex* & e, lpvar j, vector<nex_sum*> & front, nex_sum* a, nex* b) { void update_front_with_split(nex* & e, lpvar j, vector<nex**> & front, nex* a, nex* b) {
if (b == nullptr) { if (b == nullptr) {
e = mk_mul(mk_var(j), a); e = mk_mul(mk_var(j), a);
push(front, a); push(front, &(to_mul(e)->children()[1]));
TRACE("nla_cn_details", tout << "push to front " << *a << "\n";);
} else { } else {
update_front_with_split_with_non_empty_b(e, j, front, a, b); update_front_with_split_with_non_empty_b(e, j, front, a, b);
} }
} }
// it returns true if the recursion brings a cross-nested form // it returns true if the recursion brings a cross-nested form
bool split_with_var(nex*& e, lpvar j, vector<nex_sum*> & front) { bool split_with_var(nex*& e, lpvar j, vector<nex**> & front) {
SASSERT(e->is_sum()); SASSERT(e->is_sum());
TRACE("nla_cn", tout << "e = " << e << ", j=" << ch(j) << "\n";); TRACE("nla_cn", tout << "e = " << *e << ", j=" << ch(j) << "\n";);
nex_sum* a; nex * b; nex_sum* a; nex * b;
pre_split(to_sum(e), j, a, b); pre_split(to_sum(e), j, a, b);
/* /*

View file

@ -70,6 +70,11 @@ public:
virtual ~nex() {} virtual ~nex() {}
virtual bool contains(lpvar j) const { return false; } virtual bool contains(lpvar j) const { return false; }
virtual int get_degree() const = 0; virtual int get_degree() const = 0;
virtual void simplify() {}
virtual const vector<nex*> * children_ptr() const {
UNREACHABLE();
return nullptr;
}
}; };
std::ostream& operator<<(std::ostream& out, const nex&); std::ostream& operator<<(std::ostream& out, const nex&);
@ -107,6 +112,28 @@ public:
}; };
static void promote_children_by_type(vector<nex*> * children, expr_type t) {
svector<nex*> to_promote;
for(unsigned j = 0; j < children->size(); j++) {
nex* e = (*children)[j];
e->simplify();
if (e->type() == t) {
to_promote.push_back(e);
} else {
unsigned offset = to_promote.size();
if (offset) {
(*children)[j - offset] = e;
}
}
for (nex *e : to_promote) {
for (nex *ee : *(e->children_ptr())) {
children->push_back(ee);
}
}
}
}
class nex_mul : public nex { class nex_mul : public nex {
vector<nex*> m_children; vector<nex*> m_children;
public: public:
@ -115,6 +142,8 @@ public:
expr_type type() const { return expr_type::MUL; } expr_type type() const { return expr_type::MUL; }
vector<nex*>& children() { return m_children;} vector<nex*>& children() { return m_children;}
const vector<nex*>& children() const { return m_children;} const vector<nex*>& children() const { return m_children;}
const vector<nex*>* children_ptr() const { return &m_children;}
std::ostream & print(std::ostream& out) const { std::ostream & print(std::ostream& out) const {
bool first = true; bool first = true;
for (const nex* v : m_children) { for (const nex* v : m_children) {
@ -180,9 +209,13 @@ public:
return degree; return degree;
} }
void simplify() {
promote_children_by_type(&m_children, expr_type::MUL);
}
}; };
class nex_sum : public nex { class nex_sum : public nex {
vector<nex*> m_children; vector<nex*> m_children;
public: public:
@ -190,6 +223,7 @@ public:
expr_type type() const { return expr_type::SUM; } expr_type type() const { return expr_type::SUM; }
vector<nex*>& children() { return m_children;} vector<nex*>& children() { return m_children;}
const vector<nex*>& children() const { return m_children;} const vector<nex*>& children() const { return m_children;}
const vector<nex*>* children_ptr() const { return &m_children;}
unsigned size() const { return m_children.size(); } unsigned size() const { return m_children.size(); }
// we need a linear combination of at least two variables // we need a linear combination of at least two variables
@ -233,7 +267,7 @@ public:
} }
void simplify() { void simplify() {
SASSERT(false); promote_children_by_type(&m_children, expr_type::SUM);
} }
int get_degree() const { int get_degree() const {

View file

@ -86,7 +86,6 @@ void test_cn() {
nex_var* c = cn.mk_var(2); nex_var* c = cn.mk_var(2);
nex_var* d = cn.mk_var(3); nex_var* d = cn.mk_var(3);
nex_var* e = cn.mk_var(4); nex_var* e = cn.mk_var(4);
nex_var* f = cn.mk_var(5);
nex_var* g = cn.mk_var(6); nex_var* g = cn.mk_var(6);
nex* min_1 = cn.mk_scalar(rational(-1)); nex* min_1 = cn.mk_scalar(rational(-1));
// test_cn_on_expr(min_1*c*e + min_1*b*d + min_1*a*b + a*c); // test_cn_on_expr(min_1*c*e + min_1*b*d + min_1*a*b + a*c);
@ -94,8 +93,16 @@ void test_cn() {
nex_mul* bcg = cn.mk_mul(b, c, g); nex_mul* bcg = cn.mk_mul(b, c, g);
bcg->add_child(min_1); bcg->add_child(min_1);
nex_sum* t = cn.mk_sum(bcd, bcg); nex_sum* t = cn.mk_sum(bcd, bcg);
test_cn_on_expr(t, cn); // test_cn_on_expr(t, cn);
// test_cn_on_expr(a*a*d + a*b*c*d + a*a*c*c*d + a*d*d + e*a*e + e*a*c + e*d); nex* aad = cn.mk_mul(a, a, d);
nex* abcd = cn.mk_mul(a, b, c, d);
nex* aaccd = cn.mk_mul(a, a, c, c, d);
nex* add = cn.mk_mul(a, d, d);
nex* eae = cn.mk_mul(e, a, e);
nex* eac = cn.mk_mul(e, a, c);
nex* ed = cn.mk_mul(e, d);
test_cn_on_expr(cn.mk_sum(aad, abcd, aaccd, add, eae, eac, ed), cn);
// TRACE("nla_cn", tout << "done\n";); // TRACE("nla_cn", tout << "done\n";);
// test_cn_on_expr(a*b*d + a*b*c + c*b*d + a*c*d); // test_cn_on_expr(a*b*d + a*b*c + c*b*d + a*c*d);
// TRACE("nla_cn", tout << "done\n";); // TRACE("nla_cn", tout << "done\n";);