mirror of
https://github.com/Z3Prover/z3
synced 2025-04-27 19:05:51 +00:00
rewrite horner scheme on top of nex_expr as a pointer
Signed-off-by: Lev Nachmanson <levnach@hotmail.com>
This commit is contained in:
parent
0f2c8c21ff
commit
9fbd0da931
7 changed files with 563 additions and 695 deletions
|
@ -22,7 +22,6 @@
|
|||
#include "math/lp/nla_expr.h"
|
||||
namespace nla {
|
||||
class cross_nested {
|
||||
typedef nla_expr<rational> nex;
|
||||
struct occ {
|
||||
unsigned m_occs;
|
||||
unsigned m_power;
|
||||
|
@ -36,61 +35,179 @@ class cross_nested {
|
|||
};
|
||||
|
||||
// fields
|
||||
nex& m_e;
|
||||
std::function<bool (const nex&)> m_call_on_result;
|
||||
nex_sum * m_e;
|
||||
std::function<bool (const nex*)> m_call_on_result;
|
||||
std::function<bool (unsigned)> m_var_is_fixed;
|
||||
bool m_done;
|
||||
std::unordered_map<lpvar, occ> m_occurences_map;
|
||||
std::unordered_map<lpvar, unsigned> m_powers;
|
||||
|
||||
vector<nex*> m_allocated;
|
||||
vector<nex*> m_b_vec;
|
||||
public:
|
||||
cross_nested(nex &e,
|
||||
std::function<bool (const nex&)> call_on_result,
|
||||
cross_nested(std::function<bool (const nex*)> call_on_result,
|
||||
std::function<bool (unsigned)> var_is_fixed):
|
||||
m_e(e),
|
||||
m_call_on_result(call_on_result),
|
||||
m_var_is_fixed(var_is_fixed),
|
||||
m_done(false)
|
||||
{}
|
||||
|
||||
void run() {
|
||||
vector<nex*> front;
|
||||
explore_expr_on_front_elem(&m_e, front); // true for trivial form - no change
|
||||
void run(nex_sum *e) {
|
||||
m_e = e;
|
||||
|
||||
vector<nex_sum*> front;
|
||||
explore_expr_on_front_elem(m_e, front);
|
||||
}
|
||||
|
||||
static nex* pop_back(vector<nex*>& front) {
|
||||
nex* c = front.back();
|
||||
static nex_sum* pop_back(vector<nex_sum*>& front) {
|
||||
nex_sum* c = front.back();
|
||||
TRACE("nla_cn", tout << *c << "\n";);
|
||||
front.pop_back();
|
||||
return c;
|
||||
}
|
||||
|
||||
static bool extract_common_factor(nex* c, nex& f, const vector<std::pair<lpvar, occ>> & occurences) {
|
||||
nex_sum* mk_sum() {
|
||||
auto r = new nex_sum();
|
||||
m_allocated.push_back(r);
|
||||
return r;
|
||||
}
|
||||
|
||||
nex_sum* mk_sum(const vector<nex*>& v) {
|
||||
auto r = new nex_sum();
|
||||
m_allocated.push_back(r);
|
||||
r->children() = v;
|
||||
return r;
|
||||
}
|
||||
|
||||
nex_sum* mk_sum(nex *a, nex* b) {
|
||||
auto r = new nex_sum();
|
||||
m_allocated.push_back(r);
|
||||
r->children().push_back(a);
|
||||
r->children().push_back(b);
|
||||
return r;
|
||||
}
|
||||
|
||||
nex_var* mk_var(lpvar j) {
|
||||
auto r = new nex_var(j);
|
||||
m_allocated.push_back(r);
|
||||
return r;
|
||||
}
|
||||
|
||||
nex_mul* mk_mul() {
|
||||
auto r = new nex_mul();
|
||||
m_allocated.push_back(r);
|
||||
return r;
|
||||
}
|
||||
|
||||
nex_mul* mk_mul(nex * a, nex * b) {
|
||||
auto r = new nex_mul();
|
||||
m_allocated.push_back(r);
|
||||
r->add_child(a); r->add_child(b);
|
||||
return r;
|
||||
}
|
||||
|
||||
nex_mul* mk_mul(nex * a, nex * b, nex *c) {
|
||||
auto r = new nex_mul();
|
||||
m_allocated.push_back(r);
|
||||
r->add_child(a); r->add_child(b); r->add_child(c);
|
||||
return r;
|
||||
}
|
||||
|
||||
nex_scalar* mk_scalar(const rational& v) {
|
||||
auto r = new nex_scalar(v);
|
||||
m_allocated.push_back(r);
|
||||
return r;
|
||||
}
|
||||
|
||||
|
||||
nex * mk_div(const nex* a, lpvar j) {
|
||||
SASSERT(false);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
nex * mk_div(const nex* a, const nex* b) {
|
||||
TRACE("nla_cn_details", tout << *a <<" / " << *b << "\n";);
|
||||
if (b->is_var()) {
|
||||
return mk_div(a, to_var(b)->var());
|
||||
}
|
||||
SASSERT(b->is_mul());
|
||||
const nex_mul *bm = to_mul(b);
|
||||
if (a->is_sum()) {
|
||||
nex_sum * r = mk_sum();
|
||||
const nex_sum * m = to_sum(a);
|
||||
for (auto e : m->children()) {
|
||||
r->add_child(mk_div(e, bm));
|
||||
}
|
||||
TRACE("nla_cn_details", tout << *r << "\n";);
|
||||
return r;
|
||||
}
|
||||
if (a->is_var() || (a->is_mul() && to_mul(a)->children().size() == 1)) {
|
||||
return mk_scalar(rational(1));
|
||||
}
|
||||
SASSERT(a->is_mul());
|
||||
const nex_mul* am = to_mul(a);
|
||||
bm->get_powers_from_mul(m_powers);
|
||||
nex_mul* ret = new nex_mul();
|
||||
for (auto e : am->children()) {
|
||||
TRACE("nla_cn_details", tout << "e=" << *e << "\n";);
|
||||
|
||||
if (!e->is_var()) {
|
||||
SASSERT(e->is_scalar());
|
||||
ret->add_child(e);
|
||||
TRACE("nla_cn_details", tout << "continue\n";);
|
||||
continue;
|
||||
}
|
||||
SASSERT(e->is_var());
|
||||
lpvar j = to_var(e)->var();
|
||||
auto it = m_powers.find(j);
|
||||
if (it == m_powers.end()) {
|
||||
ret->add_child(e);
|
||||
} else {
|
||||
it->second --;
|
||||
if (it->second == 0)
|
||||
m_powers.erase(it);
|
||||
}
|
||||
TRACE("nla_cn_details", tout << *ret << "\n";);
|
||||
}
|
||||
SASSERT(m_powers.size() == 0);
|
||||
if (ret->children().size() == 0) {
|
||||
delete ret;
|
||||
TRACE("nla_cn_details", tout << "return 1\n";);
|
||||
return mk_scalar(rational(1));
|
||||
}
|
||||
m_allocated.push_back(ret);
|
||||
TRACE("nla_cn_details", tout << *ret << "\n";);
|
||||
return ret;
|
||||
}
|
||||
|
||||
nex* extract_common_factor(nex* e, const vector<std::pair<lpvar, occ>> & occurences) {
|
||||
nex_sum* c = to_sum(e);
|
||||
TRACE("nla_cn", tout << "c=" << *c << "\n";);
|
||||
SASSERT(c->is_sum());
|
||||
f.type() = expr_type::MUL;
|
||||
SASSERT(f.children().empty());
|
||||
unsigned size = c->children().size();
|
||||
for(const auto & p : occurences) {
|
||||
if (p.second.m_occs < size) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
nex_mul* f = mk_mul();
|
||||
for(const auto & p : occurences) { // randomize here: todo
|
||||
if (p.second.m_occs == size) {
|
||||
unsigned pow = p.second.m_power;
|
||||
while (pow --) {
|
||||
f *= nex::var(p.first);
|
||||
f->add_child(mk_var(p.first));
|
||||
}
|
||||
}
|
||||
}
|
||||
return !f.children().empty();
|
||||
return f;
|
||||
}
|
||||
|
||||
static bool has_common_factor(const nex& c) {
|
||||
TRACE("nla_cn", tout << "c=" << c << "\n";);
|
||||
SASSERT(c.is_sum());
|
||||
auto & ch = c.children();
|
||||
static bool has_common_factor(const nex_sum* c) {
|
||||
TRACE("nla_cn", tout << "c=" << *c << "\n";);
|
||||
auto & ch = c->children();
|
||||
auto common_vars = get_vars_of_expr(ch[0]);
|
||||
for (lpvar j : common_vars) {
|
||||
bool divides_the_rest = true;
|
||||
for(unsigned i = 1; i < ch.size() && divides_the_rest; i++) {
|
||||
if (!ch[i].contains(j))
|
||||
if (!ch[i]->contains(j))
|
||||
divides_the_rest = false;
|
||||
}
|
||||
if (divides_the_rest) {
|
||||
|
@ -101,45 +218,45 @@ public:
|
|||
return false;
|
||||
}
|
||||
|
||||
bool proceed_with_common_factor(nex* c, vector<nex*>& front, const vector<std::pair<lpvar, occ>> & occurences) {
|
||||
bool proceed_with_common_factor(nex*& c, vector<nex_sum*>& front, const vector<std::pair<lpvar, occ>> & occurences) {
|
||||
TRACE("nla_cn", tout << "c=" << *c << "\n";);
|
||||
SASSERT(c->is_sum());
|
||||
nex f;
|
||||
if (!extract_common_factor(c, f, occurences))
|
||||
nex* f = extract_common_factor(c, occurences);
|
||||
if (f == nullptr)
|
||||
return false;
|
||||
|
||||
*c /= f;
|
||||
f.simplify();
|
||||
* c = nex::mul(f, *c);
|
||||
TRACE("nla_cn", tout << "common factor=" << f << ", c=" << *c << "\n";);
|
||||
explore_expr_on_front_elem(&(c->children()[1]), front);
|
||||
nex_sum* c_over_f = to_sum(mk_div(c, f));
|
||||
c_over_f->simplify();
|
||||
c = mk_mul(f, c_over_f);
|
||||
TRACE("nla_cn", tout << "common factor=" << *f << ", c=" << *c << "\ne = " << *m_e << "\n";);
|
||||
|
||||
explore_expr_on_front_elem(c_over_f, front);
|
||||
return true;
|
||||
}
|
||||
|
||||
static void push(vector<nex*>& front, nex* e) {
|
||||
static void push(vector<nex_sum*>& front, nex_sum* e) {
|
||||
TRACE("nla_cn", tout << *e << "\n";);
|
||||
front.push_back(e);
|
||||
}
|
||||
|
||||
static vector<nex> copy_front(const vector<nex*>& front) {
|
||||
vector<nex> v;
|
||||
for (nex* n: front)
|
||||
v.push_back(*n);
|
||||
static vector<nex_sum*> copy_front(const vector<nex_sum*>& front) {
|
||||
vector<nex_sum*> v;
|
||||
for (nex_sum* n: front)
|
||||
v.push_back(n);
|
||||
return v;
|
||||
}
|
||||
|
||||
static void restore_front(const vector<nex> ©, vector<nex*>& front) {
|
||||
static void restore_front(const vector<nex_sum*> ©, vector<nex_sum*>& front) {
|
||||
SASSERT(copy.size() == front.size());
|
||||
for (unsigned i = 0; i < front.size(); i++)
|
||||
*(front[i]) = copy[i];
|
||||
front[i] = copy[i];
|
||||
}
|
||||
|
||||
void explore_expr_on_front_elem_occs(nex* c, vector<nex*>& front, const vector<std::pair<lpvar, occ>> & occurences) {
|
||||
void explore_expr_on_front_elem_occs(nex* c, vector<nex_sum*>& front, const vector<std::pair<lpvar, occ>> & occurences) {
|
||||
if (proceed_with_common_factor(c, front, occurences))
|
||||
return;
|
||||
TRACE("nla_cn", tout << "save c=" << *c << "; front:"; print_vector_of_ptrs(front, tout) << "\n";);
|
||||
nex copy_of_c = *c;
|
||||
vector<nex> copy_of_front = copy_front(front);
|
||||
nex* copy_of_c = c;
|
||||
auto copy_of_front = copy_front(front);
|
||||
for(auto& p : occurences) {
|
||||
SASSERT(p.second.m_occs > 1);
|
||||
lpvar j = p.first;
|
||||
|
@ -152,7 +269,7 @@ public:
|
|||
explore_of_expr_on_sum_and_var(c, j, front);
|
||||
if (m_done)
|
||||
return;
|
||||
*c = copy_of_c;
|
||||
c = copy_of_c;
|
||||
TRACE("nla_cn", tout << "restore c=" << *c << ", m_e=" << m_e << "\n";);
|
||||
restore_front(copy_of_front, front);
|
||||
TRACE("nla_cn", tout << "restore c=" << *c << "\n";);
|
||||
|
@ -171,9 +288,8 @@ public:
|
|||
return out;
|
||||
}
|
||||
|
||||
void explore_expr_on_front_elem(nex* c, vector<nex*>& front) {
|
||||
SASSERT(c->is_sum());
|
||||
auto occurences = get_mult_occurences(*c);
|
||||
void explore_expr_on_front_elem(nex_sum* c, vector<nex_sum*>& front) {
|
||||
auto occurences = get_mult_occurences(c);
|
||||
TRACE("nla_cn", tout << "m_e=" << m_e << "\nc=" << *c << ", c occurences=";
|
||||
dump_occurences(tout, occurences) << "; front:"; print_vector_of_ptrs(front, tout) << "\n";);
|
||||
|
||||
|
@ -182,7 +298,7 @@ public:
|
|||
TRACE("nla_cn", tout << "got the cn form: =" << m_e << "\n";);
|
||||
m_done = m_call_on_result(m_e);
|
||||
} else {
|
||||
nex* c = pop_back(front);
|
||||
auto c = pop_back(front);
|
||||
explore_expr_on_front_elem(c, front);
|
||||
}
|
||||
} else {
|
||||
|
@ -196,17 +312,17 @@ public:
|
|||
// return (char)('a'+j);
|
||||
}
|
||||
// e is the global expression, c is the sub expressiond which is going to changed from sum to the cross nested form
|
||||
void explore_of_expr_on_sum_and_var(nex* c, lpvar j, vector<nex*> front) {
|
||||
void explore_of_expr_on_sum_and_var(nex* & c, lpvar j, vector<nex_sum*> front) {
|
||||
TRACE("nla_cn", tout << "m_e=" << m_e << "\nc=" << *c << "\nj = " << ch(j) << "\nfront="; print_vector_of_ptrs(front, tout) << "\n";);
|
||||
if (!split_with_var(*c, j, front))
|
||||
if (!split_with_var(c, j, front))
|
||||
return;
|
||||
TRACE("nla_cn", tout << "after split c=" << *c << "\nfront="; print_vector_of_ptrs(front, tout) << "\n";);
|
||||
SASSERT(front.size());
|
||||
nex* n = pop_back(front); TRACE("nla_cn", tout << "n=" << *n <<"\n";);
|
||||
auto n = pop_back(front); TRACE("nla_cn", tout << "n=" << *n <<"\n";);
|
||||
explore_expr_on_front_elem(n, front);
|
||||
}
|
||||
|
||||
void process_var_occurences(lpvar j) {
|
||||
void add_var_occs(lpvar j) {
|
||||
auto it = m_occurences_map.find(j);
|
||||
if (it != m_occurences_map.end()) {
|
||||
it->second.m_occs++;
|
||||
|
@ -251,15 +367,14 @@ public:
|
|||
|
||||
// j -> the number of expressions j appears in as a multiplier
|
||||
// The result is sorted by large number of occurences first
|
||||
vector<std::pair<lpvar, occ>> get_mult_occurences(const nex& e) {
|
||||
vector<std::pair<lpvar, occ>> get_mult_occurences(const nex_sum* e) {
|
||||
clear_maps();
|
||||
SASSERT(e.type() == expr_type::SUM);
|
||||
for (const auto & ce : e.children()) {
|
||||
if (ce.is_mul()) {
|
||||
auto powers = ce.get_powers_from_mul();
|
||||
for (const auto * ce : e->children()) {
|
||||
if (ce->is_mul()) {
|
||||
to_mul(ce)->get_powers_from_mul(m_powers);
|
||||
update_occurences_with_powers();
|
||||
} else if (ce.type() == expr_type::VAR) {
|
||||
process_var_occurences(ce.var());
|
||||
} else if (ce->is_var()) {
|
||||
add_var_occs(to_var(ce)->var());
|
||||
}
|
||||
}
|
||||
remove_singular_occurences();
|
||||
|
@ -281,63 +396,65 @@ public:
|
|||
});
|
||||
return ret;
|
||||
}
|
||||
|
||||
static bool is_divisible_by_var(nex* ce, lpvar j) {
|
||||
return (ce->is_mul() && to_mul(ce)->contains(j))
|
||||
|| (ce->is_var() && to_var(ce)->var() == j);
|
||||
}
|
||||
// all factors of j go to a, the rest to b
|
||||
static void pre_split(nex &e, lpvar j, nex &a, nex&b) {
|
||||
for (const nex & ce: e.children()) {
|
||||
if ((ce.is_mul() && ce.contains(j)) || (ce.is_var() && ce.var() == j)) {
|
||||
a.add_child(ce / j);
|
||||
void pre_split(nex_sum * e, lpvar j, nex_sum* & a, nex* & b) {
|
||||
|
||||
a = mk_sum();
|
||||
m_b_vec.clear();
|
||||
for (nex * ce: e->children()) {
|
||||
if (is_divisible_by_var(ce, j)) {
|
||||
a->add_child(mk_div(ce , j));
|
||||
} else {
|
||||
b.add_child(ce);
|
||||
m_b_vec.push_back(ce);
|
||||
}
|
||||
}
|
||||
a.type() = expr_type::SUM;
|
||||
TRACE("nla_cn_details", tout << "a = " << a << "\n";);
|
||||
SASSERT(a.children().size() >= 2);
|
||||
a.simplify();
|
||||
TRACE("nla_cn_details", tout << "a = " << *a << "\n";);
|
||||
SASSERT(a->children().size() >= 2 && m_b_vec.size());
|
||||
a->simplify();
|
||||
|
||||
if (b.children().size() == 1) {
|
||||
nex t = b.children()[0];
|
||||
b = t;
|
||||
} else if (b.children().size() > 1) {
|
||||
b.type() = expr_type::SUM;
|
||||
}
|
||||
if (m_b_vec.size() == 1) {
|
||||
b = m_b_vec[0];
|
||||
} else {
|
||||
SASSERT(m_b_vec.size() > 1);
|
||||
b = mk_sum(m_b_vec);
|
||||
}
|
||||
}
|
||||
|
||||
// returns true if the recursion is done inside
|
||||
void update_front_with_split_with_non_empty_b(nex& e, lpvar j, vector<nex*> & front, nex& a, nex& b) {
|
||||
nex f;
|
||||
SASSERT(a.is_sum());
|
||||
void update_front_with_split_with_non_empty_b(nex* &e, lpvar j, vector<nex_sum*> & front, nex_sum* a, nex* b) {
|
||||
|
||||
SASSERT(a->is_sum());
|
||||
|
||||
TRACE("nla_cn_details", tout << "b = " << b << "\n";);
|
||||
e = nex::sum(nex::mul(nex::var(j), a), b);
|
||||
push(front, &(e.children()[0].children()[1])); // pushing 'a'
|
||||
TRACE("nla_cn", tout << "push to front " << e.children()[0].children()[1] << "\n";);
|
||||
e = mk_sum(mk_mul(mk_var(j), a), b); // e = j*a + b
|
||||
push(front, a); // pushing 'a'
|
||||
TRACE("nla_cn", tout << "push to front " << *a << "\n";);
|
||||
|
||||
if (b.is_sum()) {
|
||||
push(front, &(e.children()[1]));
|
||||
TRACE("nla_cn", tout << "push to front " << e.children()[1] << "\n";);
|
||||
if (b->is_sum()) {
|
||||
push(front, to_sum(b));
|
||||
TRACE("nla_cn", tout << "push to front " << *b << "\n";);
|
||||
}
|
||||
}
|
||||
|
||||
void update_front_with_split(nex& e, lpvar j, vector<nex*> & front, nex& a, nex& b) {
|
||||
if (b.is_undef()) {
|
||||
SASSERT(b.children().size() == 0);
|
||||
e = nex(expr_type::MUL);
|
||||
e.add_child(nex::var(j));
|
||||
e.add_child(a);
|
||||
if (a.size() > 1) {
|
||||
push(front, &e.children().back());
|
||||
TRACE("nla_cn_details", tout << "push to front " << e.children().back() << "\n";);
|
||||
}
|
||||
void update_front_with_split(nex* & e, lpvar j, vector<nex_sum*> & front, nex_sum* a, nex* b) {
|
||||
if (b == nullptr) {
|
||||
e = mk_mul(mk_var(j), a);
|
||||
push(front, a);
|
||||
TRACE("nla_cn_details", tout << "push to front " << *a << "\n";);
|
||||
} else {
|
||||
update_front_with_split_with_non_empty_b(e, j, front, a, b);
|
||||
}
|
||||
update_front_with_split_with_non_empty_b(e, j, front, a, b);
|
||||
}
|
||||
// it returns true if the recursion brings a cross-nested form
|
||||
bool split_with_var(nex& e, lpvar j, vector<nex*> & front) {
|
||||
bool split_with_var(nex*& e, lpvar j, vector<nex_sum*> & front) {
|
||||
SASSERT(e->is_sum());
|
||||
TRACE("nla_cn", tout << "e = " << e << ", j=" << ch(j) << "\n";);
|
||||
if (!e.is_sum()) return false;
|
||||
nex a, b;
|
||||
pre_split(e, j, a, b);
|
||||
nex_sum* a; nex * b;
|
||||
pre_split(to_sum(e), j, a, b);
|
||||
/*
|
||||
When we have e without a non-trivial common factor then
|
||||
there is a variable j such that e = jP + Q, where Q has all members
|
||||
|
@ -352,28 +469,42 @@ public:
|
|||
update_front_with_split(e, j, front, a, b);
|
||||
return true;
|
||||
}
|
||||
static std::unordered_set<lpvar> get_vars_of_expr(const nex &e ) {
|
||||
|
||||
static std::unordered_set<lpvar> get_vars_of_expr(const nex *e ) {
|
||||
std::unordered_set<lpvar> r;
|
||||
switch (e.type()) {
|
||||
switch (e->type()) {
|
||||
case expr_type::SCALAR:
|
||||
return r;
|
||||
case expr_type::SUM:
|
||||
{
|
||||
for (auto c: to_sum(e)->children())
|
||||
for ( lpvar j : get_vars_of_expr(c))
|
||||
r.insert(j);
|
||||
}
|
||||
case expr_type::MUL:
|
||||
{
|
||||
for (const auto & c: e.children())
|
||||
for (auto c: to_mul(e)->children())
|
||||
for ( lpvar j : get_vars_of_expr(c))
|
||||
r.insert(j);
|
||||
}
|
||||
return r;
|
||||
case expr_type::VAR:
|
||||
r.insert(e.var());
|
||||
r.insert(to_var(e)->var());
|
||||
return r;
|
||||
default:
|
||||
TRACE("nla_cn_details", tout << e.type() << "\n";);
|
||||
TRACE("nla_cn_details", tout << e->type() << "\n";);
|
||||
SASSERT(false);
|
||||
return r;
|
||||
}
|
||||
}
|
||||
|
||||
~cross_nested() {
|
||||
for (auto e: m_allocated)
|
||||
delete e;
|
||||
m_allocated.clear();
|
||||
}
|
||||
|
||||
bool done() const { return m_done; }
|
||||
|
||||
};
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue