From 8cd9989dcf726c1492db8bb5caac6c345be32d73 Mon Sep 17 00:00:00 2001 From: Lev Nachmanson Date: Thu, 26 Sep 2019 17:18:45 -0700 Subject: [PATCH] process with nex simplifications Signed-off-by: Lev Nachmanson --- src/math/lp/cross_nested.h | 14 ++--- src/math/lp/nex.cpp | 101 +++++++++++++++++++++++++++++-------- src/math/lp/nex.h | 51 +++++-------------- src/math/lp/nla_grobner.h | 3 +- src/test/lp/lp.cpp | 31 +++++++----- 5 files changed, 122 insertions(+), 78 deletions(-) diff --git a/src/math/lp/cross_nested.h b/src/math/lp/cross_nested.h index adf592820..2b47bbdfb 100644 --- a/src/math/lp/cross_nested.h +++ b/src/math/lp/cross_nested.h @@ -36,12 +36,12 @@ class cross_nested { bool m_random_bit; nex_creator m_nex_creator; nex_lt m_lt; - + std::function m_mk_scalar; #ifdef Z3DEBUG nex* m_e_clone; #endif public: - + nex_creator& get_nex_creator() { return m_nex_creator; } cross_nested(std::function call_on_result, @@ -54,7 +54,9 @@ public: m_done(false), m_reported(0), m_nex_creator(lt), - m_lt(lt) {} + m_lt(lt), + m_mk_scalar([this]{return m_nex_creator.mk_scalar(rational(1));}) + {} void run(nex *e) { @@ -128,7 +130,7 @@ public: } nex* c_over_f = m_nex_creator.mk_div(*c, f); - to_sum(c_over_f)->simplify(&c_over_f, m_lt); + to_sum(c_over_f)->simplify(&c_over_f, m_lt, m_mk_scalar); nex_mul* cm; *c = cm = m_nex_creator.mk_mul(f, c_over_f); TRACE("nla_cn", tout << "common factor=" << *f << ", c=" << **c << "\ne = " << *m_e << "\n";); @@ -393,7 +395,7 @@ public: TRACE("nla_cn_details", tout << "a = " << *a << "\n";); SASSERT(a->children().size() >= 2 && m_b_split_vec.size()); nex* f; - a->simplify(&f, m_lt); + a->simplify(&f, m_lt, m_mk_scalar); if (m_b_split_vec.size() == 1) { b = m_b_split_vec[0]; @@ -488,7 +490,7 @@ public: a->children()[j] = normalize(a->children()[j]); } nex *r; - a->simplify(&r, m_lt); + a->simplify(&r, m_lt, m_mk_scalar); return r; } diff --git a/src/math/lp/nex.cpp b/src/math/lp/nex.cpp index 6d5abe4f9..866126870 100644 --- a/src/math/lp/nex.cpp +++ b/src/math/lp/nex.cpp @@ -22,15 +22,8 @@ namespace nla { -bool ignored_child(nex* e, expr_type t) { - switch(t) { - case expr_type::MUL: - return e->is_scalar() && to_scalar(e)->value().is_one(); - case expr_type::SUM: - return e->is_scalar() && to_scalar(e)->value().is_zero(); - default: return false; - } - return false; +bool is_zero_scalar(nex* e) { + return e->is_scalar() && to_scalar(e)->value().is_zero(); } void mul_to_powers(vector& children, nex_lt lt) { @@ -54,15 +47,50 @@ void mul_to_powers(vector& children, nex_lt lt) { }); } -void promote_children_of_sum(ptr_vector & children, nex_lt lt ) { +rational extract_coeff(const nex_mul* m) { + const nex* e = m->children().begin()->e(); + if (e->is_scalar()) + return to_scalar(e)->value(); + return rational(1); +} + + +bool sum_simplify_lt(const nex_mul* a, const nex_mul* b, const nex_lt& lt) { + NOT_IMPLEMENTED_YET(); +} + +// a + 3bc + 2bc => a + 5bc +void sort_join_sum(ptr_vector & children, nex_lt& lt, std::function mk_scalar) { + ptr_vector non_muls; + std::map> + m([lt](const nex_mul *a , const nex_mul *b) { return sum_simplify_lt(a, b, lt); }); + for (nex* e : children) { + SASSERT(e->is_simplified(lt)); + if (!e->is_mul()) { + non_muls.push_back(e); + } else { + nex_mul * em = to_mul(e); + rational r = extract_coeff(em); + auto it = m.find(em); + if (it == m.end()) { + m[em] = r; + } else { + it->second += r; + } + } + } + NOT_IMPLEMENTED_YET(); +} + +void simplify_children_of_sum(ptr_vector & children, nex_lt lt, std::function mk_scalar ) { ptr_vector to_promote; int skipped = 0; for(unsigned j = 0; j < children.size(); j++) { nex** e = &(children[j]); - (*e)->simplify(e, lt); + (*e)->simplify(e, lt, mk_scalar); if ((*e)->is_sum()) { to_promote.push_back(*e); - } else if (ignored_child(*e, expr_type::SUM)) { + } else if (is_zero_scalar(*e)) { skipped ++; continue; } else { @@ -77,13 +105,15 @@ void promote_children_of_sum(ptr_vector & children, nex_lt lt ) { for (nex *e : to_promote) { for (nex *ee : *(to_sum(e)->children_ptr())) { - if (!ignored_child(ee, expr_type::SUM)) + if (!is_zero_scalar(ee)) children.push_back(ee); } - } + } + + sort_join_sum(children, lt, mk_scalar); } -bool eat_scalar(nex_scalar *& r, nex_pow& p) { +bool eat_scalar_pow(nex_scalar *& r, nex_pow& p) { if (!p.e()->is_scalar()) return false; nex_scalar *pe = to_scalar(p.e()); @@ -96,18 +126,18 @@ bool eat_scalar(nex_scalar *& r, nex_pow& p) { return true; } -void simplify_children_of_mul(vector & children, nex_lt lt) { +void simplify_children_of_mul(vector & children, nex_lt lt, std::function mk_scalar) { nex_scalar* r = nullptr; TRACE("nla_cn_details", print_vector(children, tout);); vector to_promote; int skipped = 0; for(unsigned j = 0; j < children.size(); j++) { nex_pow& p = children[j]; - if (eat_scalar(r, p)) { + if (eat_scalar_pow(r, p)) { skipped++; continue; } - (p.e())->simplify(p.ee(), lt); + (p.e())->simplify(p.ee(), lt, mk_scalar ); if ((p.e())->is_mul()) { to_promote.push_back(p); } else { @@ -122,7 +152,7 @@ void simplify_children_of_mul(vector & children, nex_lt lt) { for (nex_pow & p : to_promote) { for (nex_pow& pp : to_mul(p.e())->children()) { - if (!eat_scalar(r, pp)) + if (!eat_scalar_pow(r, pp)) children.push_back(nex_pow(pp.e(), pp.pow() * p.pow())); } } @@ -133,7 +163,36 @@ void simplify_children_of_mul(vector & children, nex_lt lt) { mul_to_powers(children, lt); - TRACE("nla_cn_details", print_vector(children, tout);); - + TRACE("nla_cn_details", print_vector(children, tout);); } + +bool less_than_nex(const nex* a, const nex* b, lt_on_vars lt) { + int r = (int)(a->type()) - (int)(b->type()); + if (r) { + return r < 0; + } + SASSERT(a->type() == b->type()); + switch (a->type()) { + case expr_type::VAR: { + return lt(to_var(a)->var() , to_var(b)->var()); + } + case expr_type::SCALAR: { + return to_scalar(a)->value() < to_scalar(b)->value(); + } + case expr_type::MUL: { + NOT_IMPLEMENTED_YET(); + return false; // to_mul(a)->children() < to_mul(b)->children(); + } + case expr_type::SUM: { + NOT_IMPLEMENTED_YET(); + return false; //to_sum(a)->children() < to_sum(b)->children(); + } + default: + SASSERT(false); + return false; + } + + return false; +} + } diff --git a/src/math/lp/nex.h b/src/math/lp/nex.h index c874c7627..2d4fc57a8 100644 --- a/src/math/lp/nex.h +++ b/src/math/lp/nex.h @@ -52,6 +52,7 @@ inline std::ostream & operator<<(std::ostream& out, expr_type t) { class nex; bool less_than_nex_standard(const nex* a, const nex* b); +class nex_scalar; // This is the class of non-linear expressions class nex { public: @@ -78,8 +79,8 @@ public: virtual bool contains(lpvar j) const { return false; } virtual int get_degree() const = 0; // simplifies the expression and also assigns the address of "this" to *e - virtual void simplify(nex** e, nex_lt) { *e = this; } - void simplify(nex** e) { return simplify(e, less_than_nex_standard); } + virtual void simplify(nex** e, nex_lt, std::function) = 0; + void simplify(nex** e, std::function mk_scalar) { return simplify(e, less_than_nex_standard, mk_scalar); } virtual bool is_simplified(nex_lt) const { return true; } @@ -115,6 +116,7 @@ public: bool contains(lpvar j) const { return j == m_j; } int get_degree() const { return 1; } bool virtual is_linear() const { return true; } + void simplify(nex** e, nex_lt, std::function) {*e = this;} }; class nex_scalar : public nex { @@ -132,6 +134,7 @@ public: int get_degree() const { return 0; } bool is_linear() const { return true; } + void simplify(nex** e, nex_lt, std::function) {*e = this;} }; @@ -139,9 +142,9 @@ const nex_scalar * to_scalar(const nex* a); class nex_sum; const nex_sum* to_sum(const nex*a); -void promote_children_of_sum(ptr_vector & children, nex_lt); +void simplify_children_of_sum(ptr_vector & children, nex_lt, std::function); class nex_pow; -void simplify_children_of_mul(vector & children, nex_lt); +void simplify_children_of_mul(vector & children, nex_lt, std::function); class nex_pow { nex* m_e; @@ -238,12 +241,12 @@ public: return degree; } // the second argument is the comparison less than operator - void simplify(nex **e, nex_lt lt) { + void simplify(nex **e, nex_lt lt, std::function mk_scalar) { TRACE("nla_cn_details", tout << *this << "\n";); TRACE("nla_cn_details", tout << "**e = " << **e << "\n";); *e = this; TRACE("nla_cn_details", tout << *this << "\n";); - simplify_children_of_mul(m_children, lt); + simplify_children_of_mul(m_children, lt, mk_scalar); if (size() == 1 && m_children[0].pow() == 1) *e = m_children[0].e(); TRACE("nla_cn_details", tout << *this << "\n";); @@ -361,9 +364,9 @@ public: return out; } - void simplify(nex **e, nex_lt lt ) { + void simplify(nex **e, nex_lt lt, std::function mk_scalar) { *e = this; - promote_children_of_sum(m_children, lt); + simplify_children_of_sum(m_children, lt, mk_scalar); if (size() == 1) *e = m_children[0]; } @@ -444,37 +447,11 @@ inline std::ostream& operator<<(std::ostream& out, const nex& e ) { } -inline bool less_than_nex(const nex* a, const nex* b, lt_on_vars lt) { - int r = (int)(a->type()) - (int)(b->type()); - if (r) { - return r < 0; - } - // here a and b have the same type - switch (a->type()) { - case expr_type::VAR: { - return lt(to_var(a)->var() , to_var(b)->var()); - } - case expr_type::SCALAR: { - return to_scalar(a)->value() < to_scalar(b)->value(); - } - case expr_type::MUL: { - NOT_IMPLEMENTED_YET(); - return false; // to_mul(a)->children() < to_mul(b)->children(); - } - case expr_type::SUM: { - NOT_IMPLEMENTED_YET(); - return false; //to_sum(a)->children() < to_sum(b)->children(); - } - default: - SASSERT(false); - return false; - } - - return false; -} +bool less_than_nex(const nex* a, const nex* b, lt_on_vars lt); inline bool less_than_nex_standard(const nex* a, const nex* b) { - return less_than_nex(a, b, [](lpvar j, lpvar k) { return j < k; }); + lt_on_vars lt = [](lpvar j, lpvar k) { return j < k; }; + return less_than_nex(a, b, lt); } #if Z3DEBUG diff --git a/src/math/lp/nla_grobner.h b/src/math/lp/nla_grobner.h index c426ef13f..42b95bd7c 100644 --- a/src/math/lp/nla_grobner.h +++ b/src/math/lp/nla_grobner.h @@ -170,7 +170,8 @@ private: } bool less_than_on_expr(const nex* a, const nex* b) const { - return less_than_nex(a, b, [this](lpvar j, lpvar k) {return less_than_on_vars(j, k);}); + lt_on_vars lt = [this](lpvar j, lpvar k) {return less_than_on_vars(j, k);}; + return less_than_nex(a, b, lt); } diff --git a/src/test/lp/lp.cpp b/src/test/lp/lp.cpp index 1c5ab4a89..ace4f686d 100644 --- a/src/test/lp/lp.cpp +++ b/src/test/lp/lp.cpp @@ -85,26 +85,31 @@ void test_simplify() { ); enable_trace("nla_cn"); enable_trace("nla_cn_details"); - auto & creator = cn.get_nex_creator(); - nex_var* a = creator.mk_var(0); - nex_var* b = creator.mk_var(1); - nex_var* c = creator.mk_var(2); - auto m = creator.mk_mul(); m->add_child_in_power(c, 2); + nex_creator & r = cn.get_nex_creator(); + nex_var* a = r.mk_var(0); + nex_var* b = r.mk_var(1); + nex_var* c = r.mk_var(2); + auto m = r.mk_mul(); m->add_child_in_power(c, 2); TRACE("nla_cn", tout << "m = " << *m << "\n";); - auto n = creator.mk_mul(a); + auto n = r.mk_mul(a); n->add_child_in_power(b, 7); - n->add_child(creator.mk_scalar(rational(3))); - n->add_child_in_power(creator.mk_scalar(rational(4)), 2); - n->add_child(creator.mk_scalar(rational(1))); + n->add_child(r.mk_scalar(rational(3))); + n->add_child_in_power(r.mk_scalar(rational(4)), 2); + n->add_child(r.mk_scalar(rational(1))); TRACE("nla_cn", tout << "n = " << *n << "\n";); m->add_child_in_power(n, 3); - n->add_child_in_power(creator.mk_scalar(rational(1, 3)), 2); + n->add_child_in_power(r.mk_scalar(rational(1, 3)), 2); TRACE("nla_cn", tout << "m = " << *m << "\n";); - nex * e = creator.mk_sum(a, creator.mk_sum(b, m)); + nex * e = r.mk_sum(a, r.mk_sum(b, m)); TRACE("nla_cn", tout << "e = " << *e << "\n";); - e->simplify(&e); + std::function mks = [&r] {return r.mk_scalar(rational(1)); }; + e->simplify(&e, mks); TRACE("nla_cn", tout << "simplified e = " << *e << "\n";); + nex * l = r.mk_sum(e, r.mk_mul(r.mk_scalar(rational(3)), r.clone(e))); + TRACE("nla_cn", tout << "sum l = " << *l << "\n";); + l->simplify(&l, mks); + TRACE("nla_cn", tout << "simplified sum l = " << *l << "\n";); } void test_cn() { @@ -142,7 +147,7 @@ void test_cn() { nex* _6aad = cn.get_nex_creator().mk_mul(cn.get_nex_creator().mk_scalar(rational(6)), a, a, d); #ifdef Z3DEBUG nex * clone = cn.get_nex_creator().clone(cn.get_nex_creator().mk_sum(_6aad, abcd, aaccd, add, eae, eac, ed)); - clone->simplify(&clone); + clone->simplify(&clone,[&cn] {return cn.get_nex_creator().mk_scalar(rational(1));}); SASSERT(clone->is_simplified()); TRACE("nla_cn", tout << "clone = " << *clone << "\n";); #endif