3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-29 11:55:51 +00:00

fix a bug in the recursion in cross_nested

Signed-off-by: Lev Nachmanson <levnach@hotmail.com>
This commit is contained in:
Lev Nachmanson 2019-07-19 15:15:46 -07:00
parent cef9726f00
commit d5708b184a
6 changed files with 122 additions and 162 deletions

View file

@ -49,56 +49,11 @@ inline std::ostream & operator<<(std::ostream& out, expr_type t) {
// This class is needed in horner calculation with intervals
template <typename T>
class nla_expr {
class sorted_children {
std::vector<nla_expr> m_es;
// m_order will be sorted according to the non-decreasing order of m_es
svector<unsigned> m_order;
public:
const std::vector<nla_expr>& es() const { return m_es; }
std::vector<nla_expr>& es() { return m_es; }
void push_back(const nla_expr& e) {
SASSERT(m_es.size() == m_order.size());
m_order.push_back(m_es.size());
m_es.push_back(e);
}
const svector<unsigned>& order() const { return m_order; }
const nla_expr& back() const { return m_es.back(); }
nla_expr& back() { return m_es.back(); }
const nla_expr* begin() const { return m_es.begin(); }
const nla_expr* end() const { return m_es.end(); }
typename std::vector<nla_expr>::iterator begin() { return m_es.begin(); }
typename std::vector<nla_expr>::iterator end() { return m_es.end(); }
unsigned size() const { return m_es.size(); }
void sort() {
std::sort(m_order.begin(), m_order.end(), [this](unsigned i, unsigned j) { return m_es[i] < m_es[j]; });
}
bool operator<(const sorted_children& e) const {
return compare(e) < 0;
}
int compare(const sorted_children& e) const {
unsigned m = std::min(size(), e.size());
for (unsigned j = 0; j < m; j++) {
int r = m_es[m_order[j]].compare(e.m_es[e.m_order[j]]);
TRACE("nla_cn_details", tout << "r=" << r << "\n";);
if (r)
return r;
}
return static_cast<int>(size()) - static_cast<int>(e.size());
}
void reset_order() {
m_order.clear();
for( unsigned i = 0; i < m_es.size(); i++)
m_order.push_back(i);
}
};
// todo: use union
expr_type m_type;
lpvar m_j;
T m_v; // for the scalar
sorted_children m_children;
vector<nla_expr> m_children;
public:
bool is_sum() const { return m_type == expr_type::SUM; }
bool is_var() const { return m_type == expr_type::VAR; }
@ -108,16 +63,13 @@ public:
lpvar var() const { SASSERT(m_type == expr_type::VAR); return m_j; }
expr_type type() const { return m_type; }
expr_type& type() { return m_type; }
const std::vector<nla_expr>& children() const { return m_children.es(); }
std::vector<nla_expr>& children() { return m_children.es(); }
const sorted_children & s_children() const { return m_children; }
sorted_children & s_children() { return m_children; }
const vector<nla_expr>& children() const { return m_children; }
vector<nla_expr>& children() { return m_children; }
const T& value() const { SASSERT(m_type == expr_type::SCALAR); return m_v; }
std::string str() const { std::stringstream ss; ss << *this; return ss.str(); }
std::ostream & print_sum(std::ostream& out) const {
bool first = true;
for (unsigned j : m_children.order()) {
const nla_expr& v = m_children.es()[j];
for (const nla_expr& v : m_children) {
std::string s = v.str();
if (first) {
first = false;
@ -139,14 +91,6 @@ public:
}
return out;
}
void sort () {
if (is_sum() || is_mul()) {
for (auto & e : m_children.es()) {
e.sort();
}
m_children.sort();
}
}
void simplify() {
if (is_simple()) return;
@ -158,7 +102,7 @@ public:
}
if (has_sum) {
nla_expr n(expr_type::SUM);
for (auto &e : m_children.es()) {
for (auto &e : m_children) {
n += e;
}
*this = n;
@ -171,19 +115,18 @@ public:
}
if (has_mul) {
nla_expr n(expr_type::MUL);
for (auto &e : m_children.es()) {
for (auto &e : m_children) {
n *= e;
}
*this = n;
}
TRACE("nla_cn", tout << "simplified " << *this << "\n";);
TRACE("nla_cn_details", tout << "simplified " << *this << "\n";);
}
}
std::ostream & print_mul(std::ostream& out) const {
bool first = true;
for (unsigned j : m_children.order()) {
const nla_expr& v = m_children.es()[j];
for (const nla_expr& v : m_children) {
std::string s = v.str();
if (first) {
first = false;
@ -212,7 +155,7 @@ public:
case expr_type::MUL:
return print_mul(out);
case expr_type::VAR:
out << 'v' << m_j;
out << (char)('a'+ m_j);
return out;
case expr_type::SCALAR:
out << m_v;
@ -320,48 +263,6 @@ public:
return false;
}
int compare(const nla_expr& e) const {
TRACE("nla_cn_details", tout << "this="<<*this<<", e=" << e << "\n";);
if (type() != e.type())
return (int)type() - (int)(e.type());
switch(m_type) {
case expr_type::SUM:
case expr_type::MUL:
return m_children.compare(e.m_children);
case expr_type::VAR:
return static_cast<int>(m_j) - static_cast<int>(e.m_j);
case expr_type::SCALAR:
return m_v < e.m_v? -1 : (m_v == e.m_v? 0 : 1);
default:
SASSERT(false);
return 0;
}
}
bool operator<(const nla_expr& e) const {
TRACE("nla_cn_details", tout << "this=" << *this << ", e=" << e << "\n";);
if (type() != (e.type()))
return (int)type() < (int)(e.type());
SASSERT(type() == (e.type()));
switch(m_type) {
case expr_type::SUM:
case expr_type::MUL:
return m_children < e.m_children;
case expr_type::VAR:
return m_j < e.m_j;
case expr_type::SCALAR:
return m_v < e.m_v;
default:
SASSERT(false);
return false;
}
}
nla_expr& operator*=(const nla_expr& b) {
if (is_mul()) {
if (b.is_mul()) {
@ -391,6 +292,7 @@ public:
it->second++;
}
}
TRACE("nla_cn_details", tout << "powers of " << *this << "\n"; print_vector(r, tout)<< "\n";);
return r;
}
@ -436,8 +338,6 @@ public:
while(k ++ < i)
children().pop_back();
s_children().reset_order();
return *this;
}
@ -458,11 +358,54 @@ public:
}
};
/*
nla_expr operator/=(const nla_expr &a, const nla_expr& b) {
TRACE("nla_cn_details", tout << a <<" / " << b << "\n";);
if (b.is_var()) {
return a / b.var();
}
SASSERT(b.is_mul());
if (a.is_sum()) {
auto r = nex::sum();
for (auto & e : a.children()) {
r.add_child(e/b);
}
return r;
}
if (is_var()) {
return scalar(T(1));
return *this;
}
SASSERT(a.is_mul());
auto powers = b.get_powers_from_mul();
auto r=nex::mul();
for (unsigned i = 0; i < a.children().size(); i++, k++) {
auto & e = children()[i];
if (!e.is_var()) {
SASSERT(e.is_scalar());
r.add_child(e);
continue;
}
lpvar j = e.var();
auto it = powers.find(j);
if (it == powers.end()) {
r.add_child(e);
} else {
it->second --; // finish h
if (it->second == 0)
powers.erase(it);
}
}
return r;
}
*/
template <typename T>
nla_expr<T> operator+(const nla_expr<T>& a, const nla_expr<T>& b) {
if (a.is_sum()) {
nla_expr<T> r(expr_type::SUM);
r.s_children() = a.s_children();
r.children() = a.children();
if (b.is_sum()) {
for (auto& e: b.children())
r.add_child(e);
@ -473,7 +416,7 @@ nla_expr<T> operator+(const nla_expr<T>& a, const nla_expr<T>& b) {
}
if (b.is_sum()) {
nla_expr<T> r(expr_type::SUM);
r.s_children() = b.s_children();
r.children() = b.children();
r.add_child(a);
return r;
}
@ -482,9 +425,13 @@ nla_expr<T> operator+(const nla_expr<T>& a, const nla_expr<T>& b) {
template <typename T>
nla_expr<T> operator*(const nla_expr<T>& a, const nla_expr<T>& b) {
if (a.is_scalar() && a.value() == T(1))
return b;
if (b.is_scalar() && b.value() == T(1))
return a;
if (a.is_mul()) {
nla_expr<T> r(expr_type::MUL);
r.s_children() = a.s_children();
r.children() = a.children();
if (b.is_mul()) {
for (auto& e: b.children())
r.add_child(e);
@ -495,7 +442,7 @@ nla_expr<T> operator*(const nla_expr<T>& a, const nla_expr<T>& b) {
}
if (b.is_mul()) {
nla_expr<T> r(expr_type::MUL);
r.s_children() = b.s_children();
r.children() = b.children();
r.add_child(a);
return r;
}