3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-29 11:55:51 +00:00

rewrite horner scheme on top of nex_expr as a pointer

Signed-off-by: Lev Nachmanson <levnach@hotmail.com>
This commit is contained in:
Lev Nachmanson 2019-08-15 17:15:45 -07:00
parent 0f2c8c21ff
commit 9fbd0da931
7 changed files with 563 additions and 695 deletions

View file

@ -47,38 +47,178 @@ inline std::ostream & operator<<(std::ostream& out, expr_type t) {
// This class is needed in horner calculation with intervals
template <typename T>
class nla_expr {
// todo: use union
expr_type m_type;
lpvar m_j;
T m_v; // for the scalar
vector<nla_expr> m_children;
class nex {
public:
bool is_sum() const { return m_type == expr_type::SUM; }
bool is_var() const { return m_type == expr_type::VAR; }
bool is_mul() const { return m_type == expr_type::MUL; }
bool is_undef() const { return m_type == expr_type::UNDEF; }
bool is_scalar() const { return m_type == expr_type::SCALAR; }
lpvar var() const { SASSERT(m_type == expr_type::VAR); return m_j; }
expr_type type() const { return m_type; }
expr_type& type() { return m_type; }
const vector<nla_expr>& children() const { return m_children; }
vector<nla_expr>& children() { return m_children; }
const T& value() const { SASSERT(m_type == expr_type::SCALAR); return m_v; }
std::string str() const { std::stringstream ss; ss << *this; return ss.str(); }
std::ostream & print_sum(std::ostream& out) const {
virtual expr_type type() const = 0;
virtual std::ostream& print(std::ostream&) const = 0;
nex() {}
bool is_simple() const {
switch(type()) {
case expr_type::SUM:
case expr_type::MUL:
return false;
default:
return true;
}
}
bool is_sum() const { return type() == expr_type::SUM; }
bool is_mul() const { return type() == expr_type::MUL; }
bool is_var() const { return type() == expr_type::VAR; }
bool is_scalar() const { return type() == expr_type::SCALAR; }
std::string str() const { std::stringstream ss; print(ss); return ss.str(); }
virtual ~nex() {}
virtual bool contains(lpvar j) const { return false; }
virtual int get_degree() const = 0;
};
std::ostream& operator<<(std::ostream& out, const nex&);
class nex_var : public nex {
lpvar m_j;
public:
nex_var(lpvar j) : m_j(j) {}
nex_var() {}
expr_type type() const { return expr_type::VAR; }
lpvar var() const { return m_j; }
lpvar& var() { return m_j; } // the setter
std::ostream & print(std::ostream& out) const {
out << 'v' << m_j;
return out;
}
bool contains(lpvar j) const { return j == m_j; }
int get_degree() const { return 1; }
};
class nex_scalar : public nex {
rational m_v;
public:
nex_scalar(const rational& v) : m_v(v) {}
nex_scalar() {}
expr_type type() const { return expr_type::SCALAR; }
const rational& value() const { return m_v; }
rational& value() { return m_v; } // the setter
std::ostream& print(std::ostream& out) const {
out << m_v;
return out;
}
int get_degree() const { return 0; }
};
class nex_mul : public nex {
vector<nex*> m_children;
public:
nex_mul() {}
unsigned size() const { return m_children.size(); }
expr_type type() const { return expr_type::MUL; }
vector<nex*>& children() { return m_children;}
const vector<nex*>& children() const { return m_children;}
std::ostream & print(std::ostream& out) const {
bool first = true;
for (const nla_expr& v : m_children) {
std::string s = v.str();
for (const nex* v : m_children) {
std::string s = v->str();
if (first) {
first = false;
if (v.is_simple())
out << v;
if (v->is_simple())
out << s;
else
out << "(" << s << ")";
} else {
if (v.is_simple()) {
if (v->is_simple()) {
if (s[0] == '-') {
out << "*(" << s << ")";
} else {
out << "*" << s;
}
} else {
out << "*(" << s << ")";
}
}
}
return out;
}
void add_child(nex* e) { m_children.push_back(e); }
bool contains(lpvar j) const {
for (const nex* c : children()) {
if (c->contains(j))
return true;
}
return false;
}
static const nex_var* to_var(const nex*a) {
SASSERT(a->is_var());
return static_cast<const nex_var*>(a);
}
void get_powers_from_mul(std::unordered_map<lpvar, unsigned> & r) const {
r.clear();
for (const auto & c : children()) {
if (!c->is_var()) {
continue;
}
lpvar j = to_var(c)->var();
auto it = r.find(j);
if (it == r.end()) {
r[j] = 1;
} else {
it->second++;
}
}
TRACE("nla_cn_details", tout << "powers of " << *this << "\n"; print_vector(r, tout)<< "\n";);
}
int get_degree() const {
int degree = 0;
for (auto e : children()) {
degree += e->get_degree();
}
return degree;
}
};
class nex_sum : public nex {
vector<nex*> m_children;
public:
nex_sum() {}
expr_type type() const { return expr_type::SUM; }
vector<nex*>& children() { return m_children;}
const vector<nex*>& children() const { return m_children;}
unsigned size() const { return m_children.size(); }
// we need a linear combination of at least two variables
bool is_a_linear_term() const {
TRACE("nex_details", tout << *this << "\n";);
unsigned number_of_non_scalars = 0;
for (auto e : children()) {
int d = e->get_degree();
if (d == 0) continue;
if (d > 1) return false;
number_of_non_scalars++;
}
TRACE("nex_details", tout << (number_of_non_scalars > 1?"linear":"non-linear") << "\n";);
return number_of_non_scalars > 1;
}
std::ostream & print(std::ostream& out) const {
bool first = true;
for (const nex* v : m_children) {
std::string s = v->str();
if (first) {
first = false;
if (v->is_simple())
out << s;
else
out << "(" << s << ")";
} else {
if (v->is_simple()) {
if (s[0] == '-') {
out << s;
} else {
@ -93,457 +233,52 @@ public:
}
void simplify() {
if (is_simple()) return;
bool has_sum = false;
if (is_sum()) {
for (auto & e : m_children) {
e.simplify();
has_sum |= e.is_sum();
}
if (has_sum) {
nla_expr n(expr_type::SUM);
for (auto &e : m_children) {
n += e;
}
m_children = n.m_children;
}
} else if (is_mul()) {
bool has_mul = false;
for (auto & e : m_children) {
e.simplify();
has_mul |= e.is_mul();
}
if (has_mul) {
nla_expr n(expr_type::MUL);
for (auto &e : m_children) {
n *= e;
}
m_children = n.m_children;
}
TRACE("nla_cn_details", tout << "simplified " << *this << "\n";);
}
}
std::ostream & print_mul(std::ostream& out) const {
bool first = true;
for (const nla_expr& v : m_children) {
std::string s = v.str();
if (first) {
first = false;
if (v.is_simple())
out << s;
else
out << "(" << s << ")";
} else {
if (v.is_simple()) {
if (s[0] == '-') {
out << "*(" << s << ")";
} else {
out << "*" << s;
}
} else {
out << "*(" << s << ")";
}
}
}
return out;
}
std::ostream & print(std::ostream& out) const {
switch(m_type) {
case expr_type::SUM:
return print_sum(out);
case expr_type::MUL:
return print_mul(out);
case expr_type::VAR:
out << 'v' << m_j;
return out;
case expr_type::SCALAR:
out << m_v;
return out;
default:
out << "undef";
return out;
}
}
bool is_simple() const {
switch(m_type) {
case expr_type::SUM:
case expr_type::MUL:
return false;
default:
return true;
}
}
unsigned size() const {
switch(m_type) {
case expr_type::SUM:
case expr_type::MUL:
return m_children.size();
default:
return 1;
}
}
nla_expr(expr_type t): m_type(t) {}
nla_expr(): m_type(expr_type::UNDEF) {}
void add_child(const nla_expr& e) {
m_children.push_back(e);
}
void add_child(const T& k) {
m_children.push_back(scalar(k));
}
void add_children() { }
template <typename K, typename ...Args>
void add_children(K e, Args ... es) {
add_child(e);
add_children(es ...);
}
template <typename K, typename ... Args>
static nla_expr sum(K e, Args ... es) {
nla_expr r(expr_type::SUM);
r.add_children(e, es...);
return r;
}
template <typename K, typename ... Args>
static nla_expr mul(K e, Args ... es) {
nla_expr r(expr_type::MUL);
r.add_children(e, es...);
return r;
}
static nla_expr mul(const T& v, nla_expr & w) {
if (v == 1)
return w;
nla_expr r(expr_type::MUL);
r.add_child(scalar(v));
r.add_child(w);
return r;
}
static nla_expr mul() {
return nla_expr(expr_type::MUL);
}
static nla_expr mul(const T& v, lpvar j) {
if (v == 1)
return var(j);
return mul(scalar(v), var(j));
}
static nla_expr scalar(const T& v) {
nla_expr r(expr_type::SCALAR);
r.m_v = v;
return r;
}
static nla_expr var(lpvar j) {
nla_expr r(expr_type::VAR);
r.m_j = j;
return r;
}
bool contains(lpvar j) const {
if (is_var())
return m_j == j;
if (is_mul()) {
for (const nla_expr<T>& c : children()) {
if (c.contains(j))
return true;
}
}
return false;
}
nla_expr& operator*=(const nla_expr& b) {
if (is_mul()) {
if (b.is_mul()) {
for (auto& e: b.children())
add_child(e);
} else {
add_child(b);
}
return *this;
}
SASSERT(false); // not impl
return *this;
}
std::unordered_map<lpvar, int> get_powers_from_mul() const {
SASSERT(is_mul());
std::unordered_map<lpvar, int> r;
for (const auto & c : children()) {
if (!c.is_var()) {
continue;
}
lpvar j = c.var();
auto it = r.find(j);
if (it == r.end()) {
r[j] = 1;
} else {
it->second++;
}
}
TRACE("nla_cn_details", tout << "powers of " << *this << "\n"; print_vector(r, tout)<< "\n";);
return r;
}
friend nla_expr operator-(const nla_expr& a, const nla_expr&b) {
return a + scalar(T(-1))*b;
SASSERT(false);
}
nla_expr& operator/=(const nla_expr& b) {
TRACE("nla_cn_details", tout << *this <<" / " << b << "\n";);
if (b.is_var()) {
*this = (*this) / b.var();
TRACE("nla_cn_details", tout << *this << "\n";);
return *this;
}
SASSERT(b.is_mul());
if (is_sum()) {
for (auto & e : children()) {
e /= b;
}
TRACE("nla_cn_details", tout << *this << "\n";);
return *this;
}
if (is_var() || children().size() == 1) {
*this = scalar(T(1));
TRACE("nla_cn_details", tout << *this << "\n";);
return *this;
}
SASSERT(is_mul());
auto powers = b.get_powers_from_mul();
unsigned i = 0, k = 0;
for (; i < children().size(); i++, k++) {
auto & e = children()[i];
TRACE("nla_cn_details", tout << "e=" << e << ",i=" <<i<< ",k=" << k<< "\n";);
if (!e.is_var()) {
SASSERT(e.is_scalar());
if (i != k)
children()[k] = children()[i];
TRACE("nla_cn_details", tout << "continue\n";);
continue;
}
lpvar j = e.var();
auto it = powers.find(j);
if (it == powers.end()) {
if (i != k)
children()[k] = children()[i];
} else {
it->second --;
if (it->second == 0)
powers.erase(it);
k--;
}
TRACE("nla_cn_details", tout << *this << "\n";);
}
SASSERT(powers.size() == 0);
while(k ++ < i)
children().pop_back();
if (children().size() == 0)
*this = scalar(T(1));
TRACE("nla_cn_details", tout << *this << "\n";);
return *this;
}
nla_expr& operator+=(const nla_expr& b) {
if (is_sum()) {
if (b.is_sum()) {
for (auto& e: b.children())
add_child(e);
} else {
add_child(b);
}
return *this;
}
SASSERT(false); // not impl
return *this;
}
// we need a linear combination of at least two variables
bool sum_is_a_linear_term() const {
SASSERT(is_sum());
TRACE("nla_expr_details", tout << *this << "\n";);
unsigned number_of_non_scalars = 0;
for (auto & e : children()) {
int d = e.get_degree();
if (d == 0) continue;
if (d > 1) return false;
number_of_non_scalars++;
}
TRACE("nla_expr_details", tout << (number_of_non_scalars > 1?"linear":"non-linear") << "\n";);
return number_of_non_scalars > 1;
}
int get_degree() const {
switch (type()) {
case expr_type::SUM: {
int degree = 0;
for (auto & e : children()) {
degree = std::max(degree, e.get_degree());
}
return degree;
int degree = 0;
for (auto e : children()) {
degree = std::max(degree, e->get_degree());
}
case expr_type::MUL: {
int degree = 0;
for (auto & e : children()) {
degree += e.get_degree();
}
return degree;
}
case expr_type::VAR:
return 1;
case expr_type::SCALAR:
return 0;
case expr_type::UNDEF:
default:
UNREACHABLE();
break;
}
return 0;
}
return degree;
}
void add_child(nex* e) { m_children.push_back(e); }
};
/*
nla_expr operator/=(const nla_expr &a, const nla_expr& b) {
TRACE("nla_cn_details", tout << a <<" / " << b << "\n";);
if (b.is_var()) {
return a / b.var();
}
SASSERT(b.is_mul());
if (a.is_sum()) {
auto r = nex::sum();
for (auto & e : a.children()) {
r.add_child(e/b);
}
return r;
}
if (is_var()) {
return scalar(T(1));
return *this;
}
SASSERT(a.is_mul());
auto powers = b.get_powers_from_mul();
auto r=nex::mul();
for (unsigned i = 0; i < a.children().size(); i++, k++) {
auto & e = children()[i];
if (!e.is_var()) {
SASSERT(e.is_scalar());
r.add_child(e);
continue;
}
lpvar j = e.var();
auto it = powers.find(j);
if (it == powers.end()) {
r.add_child(e);
} else {
it->second --; // finish h
if (it->second == 0)
powers.erase(it);
}
}
return r;
}
*/
template <typename T>
nla_expr<T> operator+(const nla_expr<T>& a, const nla_expr<T>& b) {
if (a.is_sum()) {
nla_expr<T> r(expr_type::SUM);
r.children() = a.children();
if (b.is_sum()) {
for (auto& e: b.children())
r.add_child(e);
} else {
r.add_child(b);
}
return r;
}
if (b.is_sum()) {
nla_expr<T> r(expr_type::SUM);
r.children() = b.children();
r.add_child(a);
return r;
}
return nla_expr<T>::sum(a, b);
inline const nex_sum* to_sum(const nex*a) {
SASSERT(a->is_sum());
return static_cast<const nex_sum*>(a);
}
template <typename T>
nla_expr<T> operator*(const nla_expr<T>& a, const nla_expr<T>& b) {
if (a.is_scalar() && a.value() == T(1))
return b;
if (b.is_scalar() && b.value() == T(1))
return a;
if (a.is_mul()) {
nla_expr<T> r(expr_type::MUL);
r.children() = a.children();
if (b.is_mul()) {
for (auto& e: b.children())
r.add_child(e);
} else {
r.add_child(b);
}
return r;
}
if (b.is_mul()) {
nla_expr<T> r(expr_type::MUL);
r.children() = b.children();
r.add_child(a);
return r;
}
return nla_expr<T>::mul(a, b);
inline nex_sum* to_sum(nex * a) {
SASSERT(a->is_sum());
return static_cast<nex_sum*>(a);
}
template <typename T>
nla_expr<T> operator/(const nla_expr<T>& a, lpvar j) {
TRACE("nla_cn_details", tout << "a=" << a << ", v" << j << "\n";);
SASSERT((a.is_mul() && a.contains(j)) || (a.is_var() && a.var() == j));
if (a.is_var())
return nla_expr<T>::scalar(T(1));
nla_expr<T> b;
bool seenj = false;
for (const nla_expr<T>& c : a.children()) {
if (!seenj) {
if (c.contains(j)) {
if (!c.is_var())
b.add_child(c / j);
seenj = true;
continue;
}
}
b.add_child(c);
}
if (b.children().size() > 1) {
b.type() = expr_type::MUL;
} else if (b.children().size() == 1) {
auto t = b.children()[0];
b = t;
} else {
b = nla_expr<T>::scalar(T(1));
}
return b;
inline const nex_var* to_var(const nex*a) {
SASSERT(a->is_var());
return static_cast<const nex_var*>(a);
}
template <typename T>
std::ostream& operator<<(std::ostream& out, const nla_expr<T>& e ) {
inline const nex_mul* to_mul(const nex*a) {
SASSERT(a->is_mul());
return static_cast<const nex_mul*>(a);
}
inline nex_mul* to_mul(nex*a) {
SASSERT(a->is_mul());
return static_cast<nex_mul*>(a);
}
inline const nex_scalar * to_scalar(const nex* a) {
SASSERT(a->is_scalar());
return static_cast<const nex_scalar*>(a);
}
inline std::ostream& operator<<(std::ostream& out, const nex& e ) {
return e.print(out);
}