From a844b88c32cd10240de8c449a614a3ef3443198d Mon Sep 17 00:00:00 2001 From: Lev Nachmanson Date: Mon, 19 Aug 2019 16:51:52 -0700 Subject: [PATCH] make sure that the returned cross nested form is equal to the original Signed-off-by: Lev Nachmanson --- src/math/lp/cross_nested.h | 36 +++++++++----- src/math/lp/nex.h | 99 +++++++++++++++++++++++++++++--------- src/test/lp/lp.cpp | 2 + 3 files changed, 101 insertions(+), 36 deletions(-) diff --git a/src/math/lp/cross_nested.h b/src/math/lp/cross_nested.h index babb455dd..4dc166fb9 100644 --- a/src/math/lp/cross_nested.h +++ b/src/math/lp/cross_nested.h @@ -270,7 +270,7 @@ public: } nex* c_over_f = mk_div(*c, f); - to_sum(c_over_f)->simplify(); + to_sum(c_over_f)->simplify(&c_over_f); *c = mk_mul(f, c_over_f); TRACE("nla_cn", tout << "common factor=" << *f << ", c=" << **c << "\ne = " << *m_e << "\n";); @@ -463,8 +463,7 @@ public: || (ce->is_var() && to_var(ce)->var() == j); } // all factors of j go to a, the rest to b - void pre_split(nex_sum * e, lpvar j, nex_sum* & a, nex* & b) { - + void pre_split(nex_sum * e, lpvar j, nex_sum*& a, nex*& b) { a = mk_sum(); m_b_split_vec.clear(); for (nex * ce: e->children()) { @@ -478,7 +477,8 @@ public: } TRACE("nla_cn_details", tout << "a = " << *a << "\n";); SASSERT(a->children().size() >= 2 && m_b_split_vec.size()); - a->simplify(); + nex* f; + a->simplify(&f); if (m_b_split_vec.size() == 1) { b = m_b_split_vec[0]; @@ -608,11 +608,13 @@ public: for (unsigned j = 0; j < a->size(); j ++) { a->children()[j] = normalize(a->children()[j]); } - a->simplify(); - return a; + nex *r; + a->simplify(&r); + return r; } nex * normalize_mul(nex_mul* a) { + TRACE("nla_cn", tout << *a << "\n";); int sum_j = -1; for (unsigned j = 0; j < a->size(); j ++) { a->children()[j] = normalize(a->children()[j]); @@ -620,28 +622,36 @@ public: sum_j = j; } - if (sum_j == -1) - return a; + if (sum_j == -1) { + nex * r; + a->simplify(&r); + SASSERT(r->is_simplified()); + return r; + } nex_sum *r = mk_sum(); nex_sum *as = to_sum(a->children()[sum_j]); for (unsigned k = 0; k < as->size(); k++) { nex_mul *b = mk_mul(as->children()[k]); - r->add_child(b); for (unsigned j = 0; j < a->size(); j ++) { if ((int)j != sum_j) b->add_child(a->children()[j]); } - b->simplify(); + nex *e; + b->simplify(&e); + r->add_child(e); } - TRACE("nla_cn", tout << *r << "\n";); - return normalize_sum(r); + TRACE("nla_cn", tout << *r << "\n";); + nex *rs = normalize_sum(r); + SASSERT(rs->is_simplified()); + return rs; + } nex * normalize(nex* a) { - if (a->is_simple()) + if (a->is_elementary()) return a; nex *r; if (a->is_mul()) { diff --git a/src/math/lp/nex.h b/src/math/lp/nex.h index 67da8168f..9f588806d 100644 --- a/src/math/lp/nex.h +++ b/src/math/lp/nex.h @@ -49,7 +49,7 @@ public: virtual expr_type type() const = 0; virtual std::ostream& print(std::ostream&) const = 0; nex() {} - bool is_simple() const { + bool is_elementary() const { switch(type()) { case expr_type::SUM: case expr_type::MUL: @@ -67,7 +67,10 @@ public: virtual ~nex() {} virtual bool contains(lpvar j) const { return false; } virtual int get_degree() const = 0; - virtual void simplify() {} + virtual void simplify(nex** ) = 0; + virtual bool is_simplified() const { + return true; + } virtual const ptr_vector * children_ptr() const { UNREACHABLE(); return nullptr; @@ -103,6 +106,7 @@ public: bool contains(lpvar j) const { return j == m_j; } int get_degree() const { return 1; } + virtual void simplify(nex** e) { *e = this; } }; class nex_scalar : public nex { @@ -119,29 +123,48 @@ public: } int get_degree() const { return 0; } + virtual void simplify(nex** e) { *e = this; } }; +const nex_scalar * to_scalar(const nex* a); + +static bool ignored_child(nex* e, expr_type t) { + switch(t) { + case expr_type::MUL: + return e->is_scalar() && to_scalar(e)->value().is_one(); + case expr_type::SUM: + return e->is_scalar() && to_scalar(e)->value().is_zero(); + default: return false; + } + return false; +} + static void promote_children_by_type(ptr_vector * children, expr_type t) { ptr_vector to_promote; + int skipped = 0; for(unsigned j = 0; j < children->size(); j++) { - nex* e = (*children)[j]; - e->simplify(); - if (e->type() == t) { - to_promote.push_back(e); + nex** e = &(*children)[j]; + (*e)->simplify(e); + if ((*e)->type() == t) { + to_promote.push_back(*e); + } else if (ignored_child(*e, t)) { + skipped ++; + continue; } else { - unsigned offset = to_promote.size(); + unsigned offset = to_promote.size() + skipped; if (offset) { - (*children)[j - offset] = e; + (*children)[j - offset] = *e; } } } - - children->shrink(children->size() - to_promote.size()); + + children->shrink(children->size() - to_promote.size() - skipped); for (nex *e : to_promote) { for (nex *ee : *(e->children_ptr())) { - children->push_back(ee); + if (!ignored_child(ee, t)) + children->push_back(ee); } } } @@ -163,12 +186,12 @@ public: std::string s = v->str(); if (first) { first = false; - if (v->is_simple()) + if (v->is_elementary()) out << s; else out << "(" << s << ")"; } else { - if (v->is_simple()) { + if (v->is_elementary()) { if (s[0] == '-') { out << "*(" << s << ")"; } else { @@ -222,12 +245,29 @@ public: return degree; } - void simplify() { + void simplify(nex **e) { + *e = this; TRACE("nla_cn_details", tout << *this << "\n";); promote_children_by_type(&m_children, expr_type::MUL); + if (size() == 1) + *e = m_children[0]; TRACE("nla_cn_details", tout << *this << "\n";); + SASSERT((*e)->is_simplified()); } - #ifdef Z3DEBUG + + virtual bool is_simplified() const { + if (size() < 2) + return false; + for (nex * e : children()) { + if (e->is_mul()) + return false; + if (e->is_scalar() && to_scalar(e)->value().is_one()) + return false; + } + return true; + } + +#ifdef Z3DEBUG virtual void sort() { for (nex * c : m_children) { c->sort(); @@ -271,12 +311,12 @@ public: std::string s = v->str(); if (first) { first = false; - if (v->is_simple()) + if (v->is_elementary()) out << s; else out << "(" << s << ")"; } else { - if (v->is_simple()) { + if (v->is_elementary()) { if (s[0] == '-') { out << s; } else { @@ -290,8 +330,21 @@ public: return out; } - void simplify() { + void simplify(nex **e) { + *e = this; promote_children_by_type(&m_children, expr_type::SUM); + if (size() == 1) + *e = m_children[0]; + } + virtual bool is_simplified() const { + if (size() < 2) return false; + for (nex * e : children()) { + if (e->is_sum()) + return false; + if (e->is_scalar() && to_scalar(e)->value().is_zero()) + return false; + } + return true; } int get_degree() const { @@ -331,6 +384,11 @@ inline const nex_var* to_var(const nex*a) { return static_cast(a); } +inline const nex_scalar* to_scalar(const nex*a) { + SASSERT(a->is_scalar()); + return static_cast(a); +} + inline const nex_mul* to_mul(const nex*a) { SASSERT(a->is_mul()); return static_cast(a); @@ -341,11 +399,6 @@ inline nex_mul* to_mul(nex*a) { return static_cast(a); } -inline const nex_scalar * to_scalar(const nex* a) { - SASSERT(a->is_scalar()); - return static_cast(a); -} - inline std::ostream& operator<<(std::ostream& out, const nex& e ) { return e.print(out); } diff --git a/src/test/lp/lp.cpp b/src/test/lp/lp.cpp index f0b9ef763..0433232d5 100644 --- a/src/test/lp/lp.cpp +++ b/src/test/lp/lp.cpp @@ -102,8 +102,10 @@ void test_cn() { nex* eac = cn.mk_mul(e, a, c); nex* ed = cn.mk_mul(e, d); nex* _6aad = cn.mk_mul(cn.mk_scalar(rational(6)), a, a, d); +#ifdef Z3DEBUG nex * clone = cn.clone(cn.mk_sum(_6aad, abcd, aaccd, add, eae, eac, ed)); TRACE("nla_cn", tout << "clone = " << *clone << "\n";); +#endif // test_cn_on_expr(cn.mk_sum(aad, abcd, aaccd, add, eae, eac, ed), cn); test_cn_on_expr(cn.mk_sum(_6aad, abcd, aaccd, add, eae, eac, ed), cn); // TRACE("nla_cn", tout << "done\n";);