3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-27 19:05:51 +00:00

rewrite horner scheme on top of nex_expr as a pointer

Signed-off-by: Lev Nachmanson <levnach@hotmail.com>
This commit is contained in:
Lev Nachmanson 2019-08-15 17:15:45 -07:00
parent 0f2c8c21ff
commit 9fbd0da931
7 changed files with 563 additions and 695 deletions

View file

@ -22,7 +22,6 @@
#include "math/lp/nla_expr.h"
namespace nla {
class cross_nested {
typedef nla_expr<rational> nex;
struct occ {
unsigned m_occs;
unsigned m_power;
@ -36,61 +35,179 @@ class cross_nested {
};
// fields
nex& m_e;
std::function<bool (const nex&)> m_call_on_result;
nex_sum * m_e;
std::function<bool (const nex*)> m_call_on_result;
std::function<bool (unsigned)> m_var_is_fixed;
bool m_done;
std::unordered_map<lpvar, occ> m_occurences_map;
std::unordered_map<lpvar, unsigned> m_powers;
vector<nex*> m_allocated;
vector<nex*> m_b_vec;
public:
cross_nested(nex &e,
std::function<bool (const nex&)> call_on_result,
cross_nested(std::function<bool (const nex*)> call_on_result,
std::function<bool (unsigned)> var_is_fixed):
m_e(e),
m_call_on_result(call_on_result),
m_var_is_fixed(var_is_fixed),
m_done(false)
{}
void run() {
vector<nex*> front;
explore_expr_on_front_elem(&m_e, front); // true for trivial form - no change
void run(nex_sum *e) {
m_e = e;
vector<nex_sum*> front;
explore_expr_on_front_elem(m_e, front);
}
static nex* pop_back(vector<nex*>& front) {
nex* c = front.back();
static nex_sum* pop_back(vector<nex_sum*>& front) {
nex_sum* c = front.back();
TRACE("nla_cn", tout << *c << "\n";);
front.pop_back();
return c;
}
static bool extract_common_factor(nex* c, nex& f, const vector<std::pair<lpvar, occ>> & occurences) {
nex_sum* mk_sum() {
auto r = new nex_sum();
m_allocated.push_back(r);
return r;
}
nex_sum* mk_sum(const vector<nex*>& v) {
auto r = new nex_sum();
m_allocated.push_back(r);
r->children() = v;
return r;
}
nex_sum* mk_sum(nex *a, nex* b) {
auto r = new nex_sum();
m_allocated.push_back(r);
r->children().push_back(a);
r->children().push_back(b);
return r;
}
nex_var* mk_var(lpvar j) {
auto r = new nex_var(j);
m_allocated.push_back(r);
return r;
}
nex_mul* mk_mul() {
auto r = new nex_mul();
m_allocated.push_back(r);
return r;
}
nex_mul* mk_mul(nex * a, nex * b) {
auto r = new nex_mul();
m_allocated.push_back(r);
r->add_child(a); r->add_child(b);
return r;
}
nex_mul* mk_mul(nex * a, nex * b, nex *c) {
auto r = new nex_mul();
m_allocated.push_back(r);
r->add_child(a); r->add_child(b); r->add_child(c);
return r;
}
nex_scalar* mk_scalar(const rational& v) {
auto r = new nex_scalar(v);
m_allocated.push_back(r);
return r;
}
nex * mk_div(const nex* a, lpvar j) {
SASSERT(false);
return nullptr;
}
nex * mk_div(const nex* a, const nex* b) {
TRACE("nla_cn_details", tout << *a <<" / " << *b << "\n";);
if (b->is_var()) {
return mk_div(a, to_var(b)->var());
}
SASSERT(b->is_mul());
const nex_mul *bm = to_mul(b);
if (a->is_sum()) {
nex_sum * r = mk_sum();
const nex_sum * m = to_sum(a);
for (auto e : m->children()) {
r->add_child(mk_div(e, bm));
}
TRACE("nla_cn_details", tout << *r << "\n";);
return r;
}
if (a->is_var() || (a->is_mul() && to_mul(a)->children().size() == 1)) {
return mk_scalar(rational(1));
}
SASSERT(a->is_mul());
const nex_mul* am = to_mul(a);
bm->get_powers_from_mul(m_powers);
nex_mul* ret = new nex_mul();
for (auto e : am->children()) {
TRACE("nla_cn_details", tout << "e=" << *e << "\n";);
if (!e->is_var()) {
SASSERT(e->is_scalar());
ret->add_child(e);
TRACE("nla_cn_details", tout << "continue\n";);
continue;
}
SASSERT(e->is_var());
lpvar j = to_var(e)->var();
auto it = m_powers.find(j);
if (it == m_powers.end()) {
ret->add_child(e);
} else {
it->second --;
if (it->second == 0)
m_powers.erase(it);
}
TRACE("nla_cn_details", tout << *ret << "\n";);
}
SASSERT(m_powers.size() == 0);
if (ret->children().size() == 0) {
delete ret;
TRACE("nla_cn_details", tout << "return 1\n";);
return mk_scalar(rational(1));
}
m_allocated.push_back(ret);
TRACE("nla_cn_details", tout << *ret << "\n";);
return ret;
}
nex* extract_common_factor(nex* e, const vector<std::pair<lpvar, occ>> & occurences) {
nex_sum* c = to_sum(e);
TRACE("nla_cn", tout << "c=" << *c << "\n";);
SASSERT(c->is_sum());
f.type() = expr_type::MUL;
SASSERT(f.children().empty());
unsigned size = c->children().size();
for(const auto & p : occurences) {
if (p.second.m_occs < size) {
return nullptr;
}
}
nex_mul* f = mk_mul();
for(const auto & p : occurences) { // randomize here: todo
if (p.second.m_occs == size) {
unsigned pow = p.second.m_power;
while (pow --) {
f *= nex::var(p.first);
f->add_child(mk_var(p.first));
}
}
}
return !f.children().empty();
return f;
}
static bool has_common_factor(const nex& c) {
TRACE("nla_cn", tout << "c=" << c << "\n";);
SASSERT(c.is_sum());
auto & ch = c.children();
static bool has_common_factor(const nex_sum* c) {
TRACE("nla_cn", tout << "c=" << *c << "\n";);
auto & ch = c->children();
auto common_vars = get_vars_of_expr(ch[0]);
for (lpvar j : common_vars) {
bool divides_the_rest = true;
for(unsigned i = 1; i < ch.size() && divides_the_rest; i++) {
if (!ch[i].contains(j))
if (!ch[i]->contains(j))
divides_the_rest = false;
}
if (divides_the_rest) {
@ -101,45 +218,45 @@ public:
return false;
}
bool proceed_with_common_factor(nex* c, vector<nex*>& front, const vector<std::pair<lpvar, occ>> & occurences) {
bool proceed_with_common_factor(nex*& c, vector<nex_sum*>& front, const vector<std::pair<lpvar, occ>> & occurences) {
TRACE("nla_cn", tout << "c=" << *c << "\n";);
SASSERT(c->is_sum());
nex f;
if (!extract_common_factor(c, f, occurences))
nex* f = extract_common_factor(c, occurences);
if (f == nullptr)
return false;
*c /= f;
f.simplify();
* c = nex::mul(f, *c);
TRACE("nla_cn", tout << "common factor=" << f << ", c=" << *c << "\n";);
explore_expr_on_front_elem(&(c->children()[1]), front);
nex_sum* c_over_f = to_sum(mk_div(c, f));
c_over_f->simplify();
c = mk_mul(f, c_over_f);
TRACE("nla_cn", tout << "common factor=" << *f << ", c=" << *c << "\ne = " << *m_e << "\n";);
explore_expr_on_front_elem(c_over_f, front);
return true;
}
static void push(vector<nex*>& front, nex* e) {
static void push(vector<nex_sum*>& front, nex_sum* e) {
TRACE("nla_cn", tout << *e << "\n";);
front.push_back(e);
}
static vector<nex> copy_front(const vector<nex*>& front) {
vector<nex> v;
for (nex* n: front)
v.push_back(*n);
static vector<nex_sum*> copy_front(const vector<nex_sum*>& front) {
vector<nex_sum*> v;
for (nex_sum* n: front)
v.push_back(n);
return v;
}
static void restore_front(const vector<nex> &copy, vector<nex*>& front) {
static void restore_front(const vector<nex_sum*> &copy, vector<nex_sum*>& front) {
SASSERT(copy.size() == front.size());
for (unsigned i = 0; i < front.size(); i++)
*(front[i]) = copy[i];
front[i] = copy[i];
}
void explore_expr_on_front_elem_occs(nex* c, vector<nex*>& front, const vector<std::pair<lpvar, occ>> & occurences) {
void explore_expr_on_front_elem_occs(nex* c, vector<nex_sum*>& front, const vector<std::pair<lpvar, occ>> & occurences) {
if (proceed_with_common_factor(c, front, occurences))
return;
TRACE("nla_cn", tout << "save c=" << *c << "; front:"; print_vector_of_ptrs(front, tout) << "\n";);
nex copy_of_c = *c;
vector<nex> copy_of_front = copy_front(front);
nex* copy_of_c = c;
auto copy_of_front = copy_front(front);
for(auto& p : occurences) {
SASSERT(p.second.m_occs > 1);
lpvar j = p.first;
@ -152,7 +269,7 @@ public:
explore_of_expr_on_sum_and_var(c, j, front);
if (m_done)
return;
*c = copy_of_c;
c = copy_of_c;
TRACE("nla_cn", tout << "restore c=" << *c << ", m_e=" << m_e << "\n";);
restore_front(copy_of_front, front);
TRACE("nla_cn", tout << "restore c=" << *c << "\n";);
@ -171,9 +288,8 @@ public:
return out;
}
void explore_expr_on_front_elem(nex* c, vector<nex*>& front) {
SASSERT(c->is_sum());
auto occurences = get_mult_occurences(*c);
void explore_expr_on_front_elem(nex_sum* c, vector<nex_sum*>& front) {
auto occurences = get_mult_occurences(c);
TRACE("nla_cn", tout << "m_e=" << m_e << "\nc=" << *c << ", c occurences=";
dump_occurences(tout, occurences) << "; front:"; print_vector_of_ptrs(front, tout) << "\n";);
@ -182,7 +298,7 @@ public:
TRACE("nla_cn", tout << "got the cn form: =" << m_e << "\n";);
m_done = m_call_on_result(m_e);
} else {
nex* c = pop_back(front);
auto c = pop_back(front);
explore_expr_on_front_elem(c, front);
}
} else {
@ -196,17 +312,17 @@ public:
// return (char)('a'+j);
}
// e is the global expression, c is the sub expressiond which is going to changed from sum to the cross nested form
void explore_of_expr_on_sum_and_var(nex* c, lpvar j, vector<nex*> front) {
void explore_of_expr_on_sum_and_var(nex* & c, lpvar j, vector<nex_sum*> front) {
TRACE("nla_cn", tout << "m_e=" << m_e << "\nc=" << *c << "\nj = " << ch(j) << "\nfront="; print_vector_of_ptrs(front, tout) << "\n";);
if (!split_with_var(*c, j, front))
if (!split_with_var(c, j, front))
return;
TRACE("nla_cn", tout << "after split c=" << *c << "\nfront="; print_vector_of_ptrs(front, tout) << "\n";);
SASSERT(front.size());
nex* n = pop_back(front); TRACE("nla_cn", tout << "n=" << *n <<"\n";);
auto n = pop_back(front); TRACE("nla_cn", tout << "n=" << *n <<"\n";);
explore_expr_on_front_elem(n, front);
}
void process_var_occurences(lpvar j) {
void add_var_occs(lpvar j) {
auto it = m_occurences_map.find(j);
if (it != m_occurences_map.end()) {
it->second.m_occs++;
@ -251,15 +367,14 @@ public:
// j -> the number of expressions j appears in as a multiplier
// The result is sorted by large number of occurences first
vector<std::pair<lpvar, occ>> get_mult_occurences(const nex& e) {
vector<std::pair<lpvar, occ>> get_mult_occurences(const nex_sum* e) {
clear_maps();
SASSERT(e.type() == expr_type::SUM);
for (const auto & ce : e.children()) {
if (ce.is_mul()) {
auto powers = ce.get_powers_from_mul();
for (const auto * ce : e->children()) {
if (ce->is_mul()) {
to_mul(ce)->get_powers_from_mul(m_powers);
update_occurences_with_powers();
} else if (ce.type() == expr_type::VAR) {
process_var_occurences(ce.var());
} else if (ce->is_var()) {
add_var_occs(to_var(ce)->var());
}
}
remove_singular_occurences();
@ -281,63 +396,65 @@ public:
});
return ret;
}
static bool is_divisible_by_var(nex* ce, lpvar j) {
return (ce->is_mul() && to_mul(ce)->contains(j))
|| (ce->is_var() && to_var(ce)->var() == j);
}
// all factors of j go to a, the rest to b
static void pre_split(nex &e, lpvar j, nex &a, nex&b) {
for (const nex & ce: e.children()) {
if ((ce.is_mul() && ce.contains(j)) || (ce.is_var() && ce.var() == j)) {
a.add_child(ce / j);
void pre_split(nex_sum * e, lpvar j, nex_sum* & a, nex* & b) {
a = mk_sum();
m_b_vec.clear();
for (nex * ce: e->children()) {
if (is_divisible_by_var(ce, j)) {
a->add_child(mk_div(ce , j));
} else {
b.add_child(ce);
m_b_vec.push_back(ce);
}
}
a.type() = expr_type::SUM;
TRACE("nla_cn_details", tout << "a = " << a << "\n";);
SASSERT(a.children().size() >= 2);
a.simplify();
TRACE("nla_cn_details", tout << "a = " << *a << "\n";);
SASSERT(a->children().size() >= 2 && m_b_vec.size());
a->simplify();
if (b.children().size() == 1) {
nex t = b.children()[0];
b = t;
} else if (b.children().size() > 1) {
b.type() = expr_type::SUM;
}
if (m_b_vec.size() == 1) {
b = m_b_vec[0];
} else {
SASSERT(m_b_vec.size() > 1);
b = mk_sum(m_b_vec);
}
}
// returns true if the recursion is done inside
void update_front_with_split_with_non_empty_b(nex& e, lpvar j, vector<nex*> & front, nex& a, nex& b) {
nex f;
SASSERT(a.is_sum());
void update_front_with_split_with_non_empty_b(nex* &e, lpvar j, vector<nex_sum*> & front, nex_sum* a, nex* b) {
SASSERT(a->is_sum());
TRACE("nla_cn_details", tout << "b = " << b << "\n";);
e = nex::sum(nex::mul(nex::var(j), a), b);
push(front, &(e.children()[0].children()[1])); // pushing 'a'
TRACE("nla_cn", tout << "push to front " << e.children()[0].children()[1] << "\n";);
e = mk_sum(mk_mul(mk_var(j), a), b); // e = j*a + b
push(front, a); // pushing 'a'
TRACE("nla_cn", tout << "push to front " << *a << "\n";);
if (b.is_sum()) {
push(front, &(e.children()[1]));
TRACE("nla_cn", tout << "push to front " << e.children()[1] << "\n";);
if (b->is_sum()) {
push(front, to_sum(b));
TRACE("nla_cn", tout << "push to front " << *b << "\n";);
}
}
void update_front_with_split(nex& e, lpvar j, vector<nex*> & front, nex& a, nex& b) {
if (b.is_undef()) {
SASSERT(b.children().size() == 0);
e = nex(expr_type::MUL);
e.add_child(nex::var(j));
e.add_child(a);
if (a.size() > 1) {
push(front, &e.children().back());
TRACE("nla_cn_details", tout << "push to front " << e.children().back() << "\n";);
}
void update_front_with_split(nex* & e, lpvar j, vector<nex_sum*> & front, nex_sum* a, nex* b) {
if (b == nullptr) {
e = mk_mul(mk_var(j), a);
push(front, a);
TRACE("nla_cn_details", tout << "push to front " << *a << "\n";);
} else {
update_front_with_split_with_non_empty_b(e, j, front, a, b);
}
update_front_with_split_with_non_empty_b(e, j, front, a, b);
}
// it returns true if the recursion brings a cross-nested form
bool split_with_var(nex& e, lpvar j, vector<nex*> & front) {
bool split_with_var(nex*& e, lpvar j, vector<nex_sum*> & front) {
SASSERT(e->is_sum());
TRACE("nla_cn", tout << "e = " << e << ", j=" << ch(j) << "\n";);
if (!e.is_sum()) return false;
nex a, b;
pre_split(e, j, a, b);
nex_sum* a; nex * b;
pre_split(to_sum(e), j, a, b);
/*
When we have e without a non-trivial common factor then
there is a variable j such that e = jP + Q, where Q has all members
@ -352,28 +469,42 @@ public:
update_front_with_split(e, j, front, a, b);
return true;
}
static std::unordered_set<lpvar> get_vars_of_expr(const nex &e ) {
static std::unordered_set<lpvar> get_vars_of_expr(const nex *e ) {
std::unordered_set<lpvar> r;
switch (e.type()) {
switch (e->type()) {
case expr_type::SCALAR:
return r;
case expr_type::SUM:
{
for (auto c: to_sum(e)->children())
for ( lpvar j : get_vars_of_expr(c))
r.insert(j);
}
case expr_type::MUL:
{
for (const auto & c: e.children())
for (auto c: to_mul(e)->children())
for ( lpvar j : get_vars_of_expr(c))
r.insert(j);
}
return r;
case expr_type::VAR:
r.insert(e.var());
r.insert(to_var(e)->var());
return r;
default:
TRACE("nla_cn_details", tout << e.type() << "\n";);
TRACE("nla_cn_details", tout << e->type() << "\n";);
SASSERT(false);
return r;
}
}
~cross_nested() {
for (auto e: m_allocated)
delete e;
m_allocated.clear();
}
bool done() const { return m_done; }
};
}